
Trouble restoring a pretrained model in TensorFlow

I ran the word2vec demo program included in TensorFlow, and I am now trying to restore the pretrained model from its checkpoint files, but it doesn't work.

I ran this script file: https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/embedding/word2vec.py

Then I tried to run this file:

#!/usr/bin/env python

import tensorflow as tf

FILENAME_META = "model.ckpt-70707299.meta"
FILENAME_CHECKPOINT = "model.ckpt-70707299"


def main():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(FILENAME_META)
        saver.restore(sess, FILENAME_CHECKPOINT)


if __name__ == '__main__':
    main()

It fails with the following error message:

Traceback (most recent call last):
  File "word2vec_restore.py", line 16, in <module>
    main()
  File "word2vec_restore.py", line 11, in main
    saver = tf.train.import_meta_graph(FILENAME_META)
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1431, in import_meta_graph
    return _import_meta_graph_def(read_meta_graph_file(meta_graph_or_file))
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1321, in _import_meta_graph_def
    producer_op_list=producer_op_list)
  File "/home/kato/.pyenv/versions/3.5.1/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 247, in import_graph_def
    op_def = op_dict[node.op]
KeyError: 'Skipgram'

I believe I have understood the TensorFlow API documentation, and I implemented the code above as described there. Am I using the Saver object in the wrong way?

I solved this myself. I wondered where the key 'Skipgram' came from and dug into the source code. To solve the problem, just add the following import at the top of the script:

from tensorflow.models.embedding import gen_word2vec

I still don't understand exactly what this does, but presumably the import is necessary because it loads the related library written in C++, which registers the 'Skipgram' op before the graph is imported.
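The underlying mechanism resembles a common op-registry pattern: importing a module executes its top-level code as a side effect, which records the op under its name, and graph loading then looks the op up by that name. Here is a minimal, TensorFlow-free sketch of that pattern; all names below are illustrative, not TensorFlow's actual internals:

```python
# Illustrative sketch of the op-registry pattern (NOT TensorFlow's real
# machinery): graph loading looks ops up by name in a registry, and an op
# only appears there if the module defining it has been imported.

OP_REGISTRY = {}

def register_op(name):
    """Decorator that records an op implementation under its name."""
    def deco(fn):
        OP_REGISTRY[name] = fn
        return fn
    return deco

def load_graph_node(op_name):
    """Mimics importer.py's `op_dict[node.op]` lookup: raises KeyError
    if the op was never registered."""
    return OP_REGISTRY[op_name]

# Before any registration, loading fails just like in the traceback above:
try:
    load_graph_node("Skipgram")
except KeyError as e:
    print("KeyError:", e)  # KeyError: 'Skipgram'

# "Importing" the module that defines the op registers it as a side effect,
# which is what `from tensorflow.models.embedding import gen_word2vec` does.
@register_op("Skipgram")
def skipgram_op():
    return "skipgram kernel"

print(load_graph_node("Skipgram")())
```

After the registration runs, the same lookup that previously raised `KeyError: 'Skipgram'` succeeds, which matches the observed effect of adding the import.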

Thanks.

Try the following:

saver = tf.train.Saver()
with tf.Session() as sess:
    # Look up the checkpoint state recorded in checkpoint_dir.
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir)
    if checkpoint and checkpoint.model_checkpoint_path:
        saver.restore(sess, checkpoint.model_checkpoint_path)

Here checkpoint_dir is the path to the folder that contains the checkpoint files, not the full path to the meta or checkpoint files. TensorFlow itself picks the latest checkpoint from the specified folder.
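The way the latest checkpoint is found is simple: the training run writes a small text file named `checkpoint` into the directory, recording the most recent checkpoint path. Below is a simplified, illustrative sketch of that lookup (the real `get_checkpoint_state` parses a CheckpointState protobuf; the string parsing here is just for demonstration):

```python
# Simplified sketch of how the latest checkpoint is found: read the small
# "checkpoint" text file in the directory, which records the newest path.
# (The real implementation parses a CheckpointState protobuf.)
import os
import re
import tempfile

def latest_checkpoint_path(checkpoint_dir):
    """Return the model_checkpoint_path recorded in <dir>/checkpoint, or None."""
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            m = re.match(r'model_checkpoint_path:\s*"(.*)"', line.strip())
            if m:
                return m.group(1)
    return None

# Demo: write a "checkpoint" file shaped like the ones TensorFlow produces.
with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "model.ckpt-70707299"\n')
        f.write('all_model_checkpoint_paths: "model.ckpt-70707299"\n')
    print(latest_checkpoint_path(d))  # model.ckpt-70707299
```

This is also why passing the directory (rather than a full file path) is enough: the directory's `checkpoint` file already names the newest checkpoint.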
