简体   繁体   English

在张量流中恢复预训练模型的问题

[英]Trouble with restoring pretrained model in tensorflow

I ran the demo program of word2vec which is included in TensorFlow, and now trying to restore the pretrained model from files, but it doesn't work. 我运行了TensorFlow中包含的word2vec演示程序,现在尝试从文件中恢复经过预训练的模型,但是它不起作用。

I ran this script file: https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/embedding/word2vec.py 我运行了这个脚本文件: https : //github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/embedding/word2vec.py

Then I tried to run this file: 然后,我尝试运行此文件:

#!/usr/bin/env python

import tensorflow as tf

FILENAME_META = "model.ckpt-70707299.meta"
FILENAME_CHECKPOINT = "model.ckpt-70707299"


def main():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(FILENAME_META)
        saver.restore(sess, FILENAME_CHECKPOINT)


if __name__ == '__main__':
    main()

It fails with the following error message 它失败,并显示以下错误消息

Traceback (most recent call last):
  File "word2vec_restore.py", line 16, in <module>
    main()
  File "word2vec_restore.py", line 11, in main
    saver = tf.train.import_meta_graph(FILENAME_META)
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1431, in import_meta_graph
    return _import_meta_graph_def(read_meta_graph_file(meta_graph_or_file))
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1321, in _import_meta_graph_def
    producer_op_list=producer_op_list)
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 247, in import_graph_def
    op_def = op_dict[node.op]
KeyError: 'Skipgram'

I regard I have understood the API document of TensorFlow, and I implemented the code above as is written in it. 我认为我已经了解了TensorFlow的API文档,并且按照上面编写的代码实现了上面的代码。 Am I using the Saver object in a wrong way? 我是否以错误的方式使用了Saver对象?

I solved this by myself. 我自己解决了。 I wondered where the key 'Skipgram' comes from, and dug the source code. 我想知道密钥“ Skipgram”从何而来,并挖出了源代码。 To solve the problem, just add the following on the top: 要解决此问题,只需在顶部添加以下内容:

from tensorflow.models.embedding import gen_word2vec

I still don't understand exactly what I am doing, but maybe this is because it is necessary to load a related library written in C++. 我仍然不清楚我在做什么,但这也许是因为有必要加载用C ++编写的相关库。

Thanks. 谢谢。

Try the following: 请尝试以下操作:

saver = tf.train.Saver()
with tf.Session() as sess:
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)

Where checkpoint_dir is path to folder that contains checkpoint files, not full path to meta or checkpoint files. 其中checkpoint_dir是包含检查点文件的文件夹的路径,而不是meta或检查点文件的完整路径。 Tensorflow picks the latest checkpoint itself from the specified folder. Tensorflow从指定的文件夹中选择最新的检查点本身。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM