简体   繁体   English

当state_is_tuple = True时,如何设置TensorFlow RNN状态?

[英]How do I set TensorFlow RNN state when state_is_tuple=True?

I have written an RNN language model using TensorFlow . 使用TensorFlow编写了一个RNN语言模型 The model is implemented as an RNN class. 该模型实现为RNN类。 The graph structure is built in the constructor, while RNN.train and RNN.test methods run it. 图结构是在构造函数中构建的,而RNN.trainRNN.test方法则运行它。

I want to be able to reset the RNN state when I move to a new document in the training set, or when I want to run a validation set during training. 我想在移动到训练集中的新文档时,或者当我想在训练期间运行验证集时,能够重置RNN状态。 I do this by managing the state inside the training loop, passing it into the graph via a feed dictionary. 我通过管理训练循环内的状态,通过提要字典将其传递到图表中来实现此目的。

In the constructor I define the the RNN like so 在构造函数中,我像这样定义RNN

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

The training loop looks like this 训练循环看起来像这样

 for document in document:
     state = session.run(self.reset_state)
     for x, y in document:
          _, state = session.run([self.train_step, self.next_state], 
                                 feed_dict={self.x:x, self.y:y, self.state:state})

x and y are batches of training data in a document. xy是文档中的批量训练数据。 The idea is that I pass the latest state along after each batch, except when I start a new document, when I zero out the state by running self.reset_state . 我的想法是,每次批处理后都会传递最新的状态,除非我启动一个新文档,当我通过运行self.reset_state将状态归零时。

This all works. 这一切都有效。 Now I want to change my RNN to use the recommended state_is_tuple=True . 现在我想更改我的RNN以使用推荐的state_is_tuple=True However, I don't know how to pass the more complicated LSTM state object via a feed dictionary. 但是,我不知道如何通过提要字典传递更复杂的LSTM状态对象。 Also I don't know what arguments to pass to the self.state = tf.placeholder(...) line in my constructor. 另外我不知道在self.state = tf.placeholder(...)函数中传递给self.state = tf.placeholder(...)行的参数。

What is the correct strategy here? 这里的正确策略是什么? There still isn't much example code or documentation for dynamic_rnn available. 可用的dynamic_rnn仍然没有太多示例代码或文档。


TensorFlow issues 2695 and 2838 appear relevant. TensorFlow问题26952838似乎相关。

A blog post on WILDML addresses these issues but doesn't directly spell out the answer. 关于WILDML的博客文章解决了这些问题,但没有直接说明答案。

See also TensorFlow: Remember LSTM state for next batch (stateful LSTM) . 另请参见TensorFlow:记住下一批次的LSTM状态(有状态LSTM)

One problem with a Tensorflow placeholder is that you can only feed it with a Python list or Numpy array (I think). Tensorflow占位符的一个问题是你只能用Python列表或Numpy数组(我认为)来提供它。 So you can't save the state between runs in tuples of LSTMStateTuple. 因此,您无法在LSTMStateTuple的元组中的运行之间保存状态。

I solved this by saving the state in a tensor like this 我通过将状态保存在这样的张量中来解决这个问题

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

You have two components in an LSTM layer, the cell state and hidden state , thats what the "2" comes from. LSTM层中有两个组件,即单元状态隐藏状态 ,这就是“2”的来源。 (this article is great: https://arxiv.org/pdf/1506.00019.pdf ) (这篇文章很棒: https//arxiv.org/pdf/1506.00019.pdf

When building the graph you unpack and create the tuple state like this: 构建图形时,解压缩并创建元组状态,如下所示:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
         [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
          for idx in range(num_layers)]
)

Then you get the new state the usual way 然后你通常的方式得到新的状态

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

It shouldn't be like this... perhaps they are working on a solution. 它应该不是这样的......也许他们正在努力解决问题。

A simple way to feed in an RNN state is to simply feed in both components of the state tuple individually. 在RNN状态下馈送的简单方法是单独地馈送状态元组的两个分量。

# Constructing the graph
self.state = rnn_cell.zero_state(...)
self.output, self.next_state = tf.nn.dynamic_rnn(
    rnn_cell,
    self.input,
    initial_state=self.state)

# Running with initial state
output, state = sess.run([self.output, self.next_state], feed_dict={
    self.input: input
})

# Running with subsequent state:
output, state = sess.run([self.output, self.next_state], feed_dict={
    self.input: input,
    self.state[0]: state[0],
    self.state[1]: state[1]
})

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM