简体   繁体   English

Tensorflow:从还原的RNN撤回隐藏状态

[英]Tensorflow : Retreive hidden state from restored RNN

I'd like to restore an RNN and get the hidden state. 我想还原一个RNN并获得隐藏状态。

I do something like that to save the RNN: 我这样做是为了保存RNN:

loc="path/to/save/rnn"
with tf.variable_scope("lstm") as scope:
    outputs, state = tf.nn.dynamic_rnn(..)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
save_path = saver.save(sess,loc)

Now I want to retreive state . 现在我要撤退state

graph = tf.Graph()
sess = tf.Session(graph=graph)
with graph.as_default():
      saver = tf.train.import_meta_graph(loc + '.meta', clear_devices=True)
      saver.restore(sess, loc)
      state= ...

You can add the state tensor to a graph collection , which is basically a key value store to track tensors, using tf.add_to_collection and retrieve it later using tf.get_collection . 您可以使用tf.add_to_collectionstate张量添加到图集合中 ,该集合基本上是用于跟踪张量的键值存储,以后再使用tf.get_collection检索它。 For example: 例如:

loc="path/to/save/rnn"
with tf.variable_scope("lstm") as scope:
    outputs, state = tf.nn.dynamic_rnn(..)
    tf.add_to_collection('state', state)


graph = tf.Graph()
with graph.as_default():
      saver = tf.train.import_meta_graph(loc + '.meta', clear_devices=True)
      state = tf.get_collection('state')[0]  # Note: tf.get_collection returns a list.

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM