[英]Trouble with restoring pretrained model in tensorflow
我運行了TensorFlow中包含的word2vec演示程序,現在嘗試從文件中恢復經過預訓練的模型,但是它不起作用。
我運行了這個腳本文件: https : //github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/embedding/word2vec.py
然后,我嘗試運行此文件:
#!/usr/bin/env python
import tensorflow as tf
FILENAME_META = "model.ckpt-70707299.meta"
FILENAME_CHECKPOINT = "model.ckpt-70707299"
def main():
with tf.Session() as sess:
saver = tf.train.import_meta_graph(FILENAME_META)
saver.restore(sess, FILENAME_CHECKPOINT)
if __name__ == '__main__':
main()
它失敗,並顯示以下錯誤消息
Traceback (most recent call last):
File "word2vec_restore.py", line 16, in <module>
main()
File "word2vec_restore.py", line 11, in main
saver = tf.train.import_meta_graph(FILENAME_META)
File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1431, in import_meta_graph
return _import_meta_graph_def(read_meta_graph_file(meta_graph_or_file))
File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1321, in _import_meta_graph_def
producer_op_list=producer_op_list)
File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 247, in import_graph_def
op_def = op_dict[node.op]
KeyError: 'Skipgram'
我認為我已經了解了TensorFlow的API文檔,並且按照上面編寫的代碼實現了上面的代碼。 我是否以錯誤的方式使用了Saver對象?
我自己解決了。 我想知道密鑰“ Skipgram”從何而來,並挖出了源代碼。 要解決此問題,只需在頂部添加以下內容:
from tensorflow.models.embedding import gen_word2vec
我仍然不清楚我在做什么,但這也許是因為有必要加載用C ++編寫的相關庫。
謝謝。
請嘗試以下操作:
saver = tf.train.Saver()
with tf.Session() as sess:
checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
if checkpoint and checkpoint.model_checkpoint_path:
saver.restore(sess, checkpoint.model_checkpoint_path)
其中checkpoint_dir
是包含檢查點文件的文件夾的路徑,而不是meta或檢查點文件的完整路徑。 Tensorflow從指定的文件夾中選擇最新的檢查點本身。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.