[英]How to use trained BERT model checkpoints for prediction?
我使用 SQUAD 2.0 训练了 BERT 并使用BERT-master在输出目录中获得了model.ckpt.data
、 model.ckpt.meta
、 model.ckpt.index
(F1 分数:81)以及predictions.json
等/run_squad.py
python run_squad.py \
--vocab_file=$BERT_LARGE_DIR/vocab.txt \
--bert_config_file=$BERT_LARGE_DIR/bert_config.json \
--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
--do_train=True \
--train_file=$SQUAD_DIR/train-v2.0.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v2.0.json \
--train_batch_size=24 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=gs://some_bucket/squad_large/ \
--use_tpu=True \
--tpu_name=$TPU_NAME \
--version_2_with_negative=True
我尝试将model.ckpt.meta
、 model.ckpt.index
、 model.ckpt.data
到$BERT_LARGE_DIR
目录并更改run_squad.py
标志如下仅预测答案而不使用数据集进行训练:
python run_squad.py \
--vocab_file=$BERT_LARGE_DIR/vocab.txt \
--bert_config_file=$BERT_LARGE_DIR/bert_config.json \
--init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
--do_train=False \
--train_file=$SQUAD_DIR/train-v2.0.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v2.0.json \
--train_batch_size=24 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=gs://some_bucket/squad_large/ \
--use_tpu=True \
--tpu_name=$TPU_NAME \
--version_2_with_negative=True
它抛出bucket directory/model.ckpt不存在错误。
如何利用训练后生成的检查点并将其用于预测?
通常,训练的检查点是在训练时在--output_dir
参数指定的目录中--output_dir
。 (在您的情况下是gs://some_bucket/squad_large/
)。 每个检查点都会有一个编号。 你必须找出最大的数字; 例如: model.ckpt-12345
。 现在,使用输出目录和最后保存的检查点(编号最高的模型)在您的评估/预测中设置--init_checkpoint
参数。 (在您的情况下,它应该类似于--init_checkpoint=gs://some_bucket/squad_large/model.ckpt-<highest number>
)
在第二个代码中 FLAG init_checkpoint
我认为它应该是:
--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt
与上面的一样,而不是--init_checkpoint=$BERT_LARGE_DIR/model.ckpt
。
如果问题仍然存在,您是否使用了multi_cased_L-12_H-768_A-12
预训练模型?
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.