簡體   English   中英

在張量流中恢復預訓練模型的問題

[英]Trouble with restoring pretrained model in tensorflow

我運行了TensorFlow中包含的word2vec演示程序,現在嘗試從文件中恢復經過預訓練的模型,但是它不起作用。

我運行了這個腳本文件: https : //github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/embedding/word2vec.py

然后,我嘗試運行此文件:

#!/usr/bin/env python

import tensorflow as tf

FILENAME_META = "model.ckpt-70707299.meta"
FILENAME_CHECKPOINT = "model.ckpt-70707299"


def main():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(FILENAME_META)
        saver.restore(sess, FILENAME_CHECKPOINT)


if __name__ == '__main__':
    main()

它失敗,並顯示以下錯誤消息

Traceback (most recent call last):
  File "word2vec_restore.py", line 16, in <module>
    main()
  File "word2vec_restore.py", line 11, in main
    saver = tf.train.import_meta_graph(FILENAME_META)
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1431, in import_meta_graph
    return _import_meta_graph_def(read_meta_graph_file(meta_graph_or_file))
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1321, in _import_meta_graph_def
    producer_op_list=producer_op_list)
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 247, in import_graph_def
    op_def = op_dict[node.op]
KeyError: 'Skipgram'

我認為我已經了解了TensorFlow的API文檔,並且按照上面編寫的代碼實現了上面的代碼。 我是否以錯誤的方式使用了Saver對象?

我自己解決了。 我想知道密鑰“ Skipgram”從何而來,並挖出了源代碼。 要解決此問題,只需在頂部添加以下內容:

from tensorflow.models.embedding import gen_word2vec

我仍然不清楚我在做什么,但這也許是因為有必要加載用C ++編寫的相關庫。

謝謝。

請嘗試以下操作:

saver = tf.train.Saver()
with tf.Session() as sess:
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)

其中checkpoint_dir是包含檢查點文件的文件夾的路徑,而不是meta或檢查點文件的完整路徑。 Tensorflow從指定的文件夾中選擇最新的檢查點本身。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM