[英]Trouble with restoring pretrained model in tensorflow
I ran the demo program of word2vec which is included in TensorFlow, and now trying to restore the pretrained model from files, but it doesn't work. 我运行了TensorFlow中包含的word2vec演示程序,现在尝试从文件中恢复经过预训练的模型,但是它不起作用。
I ran this script file: https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/embedding/word2vec.py 我运行了这个脚本文件: https : //github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/embedding/word2vec.py
Then I tried to run this file: 然后,我尝试运行此文件:
#!/usr/bin/env python
import tensorflow as tf
FILENAME_META = "model.ckpt-70707299.meta"
FILENAME_CHECKPOINT = "model.ckpt-70707299"
def main():
with tf.Session() as sess:
saver = tf.train.import_meta_graph(FILENAME_META)
saver.restore(sess, FILENAME_CHECKPOINT)
if __name__ == '__main__':
main()
It fails with the following error message 它失败,并显示以下错误消息
Traceback (most recent call last):
File "word2vec_restore.py", line 16, in <module>
main()
File "word2vec_restore.py", line 11, in main
saver = tf.train.import_meta_graph(FILENAME_META)
File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1431, in import_meta_graph
return _import_meta_graph_def(read_meta_graph_file(meta_graph_or_file))
File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1321, in _import_meta_graph_def
producer_op_list=producer_op_list)
File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 247, in import_graph_def
op_def = op_dict[node.op]
KeyError: 'Skipgram'
I regard I have understood the API document of TensorFlow, and I implemented the code above as is written in it. 我认为我已经了解了TensorFlow的API文档,并且按照上面编写的代码实现了上面的代码。 Am I using the Saver object in a wrong way? 我是否以错误的方式使用了Saver对象?
I solved this by myself. 我自己解决了。 I wondered where the key 'Skipgram' comes from, and dug the source code. 我想知道密钥“ Skipgram”从何而来,并挖出了源代码。 To solve the problem, just add the following on the top: 要解决此问题,只需在顶部添加以下内容:
from tensorflow.models.embedding import gen_word2vec
I still don't understand exactly what I am doing, but maybe this is because it is necessary to load a related library written in C++. 我仍然不清楚我在做什么,但这也许是因为有必要加载用C ++编写的相关库。
Thanks. 谢谢。
Try the following: 请尝试以下操作:
saver = tf.train.Saver()
with tf.Session() as sess:
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
if checkpoint and checkpoint.model_checkpoint_path:
saver.restore(sess, checkpoint.model_checkpoint_path)
Where checkpoint_dir
is path to folder that contains checkpoint files, not full path to meta or checkpoint files. 其中checkpoint_dir
是包含检查点文件的文件夹的路径,而不是meta或检查点文件的完整路径。 Tensorflow picks the latest checkpoint itself from the specified folder. Tensorflow从指定的文件夹中选择最新的检查点本身。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.