
How do I set TensorFlow RNN state when state_is_tuple=True?

I have written an RNN language model using TensorFlow. The model is implemented as an RNN class. The graph structure is built in the constructor, while the RNN.train and RNN.test methods run it.

I want to be able to reset the RNN state when I move to a new document in the training set, or when I want to run a validation set during training. I do this by managing the state inside the training loop, passing it into the graph via a feed dictionary.

In the constructor I define the RNN like so:

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

The training loop looks like this

    for document in documents:
        state = session.run(self.reset_state)
        for x, y in document:
            _, state = session.run([self.train_step, self.next_state],
                                   feed_dict={self.x: x, self.y: y, self.state: state})

x and y are batches of training data in a document. The idea is that I pass the latest state along after each batch, except when I start a new document, at which point I zero out the state by running self.reset_state.

This all works. Now I want to change my RNN to use the recommended state_is_tuple=True. However, I don't know how to pass the more complicated LSTM state object via a feed dictionary, nor what arguments to pass to the self.state = tf.placeholder(...) line in my constructor.

What is the correct strategy here? There still isn't much example code or documentation for dynamic_rnn available.


TensorFlow issues 2695 and 2838 appear relevant.

A blog post on WILDML addresses these issues but doesn't directly spell out the answer.

See also TensorFlow: Remember LSTM state for next batch (stateful LSTM).

One problem with a TensorFlow placeholder is that you can only feed it with a Python list or NumPy array (I think). So you can't save the state between runs in tuples of LSTMStateTuple.

I solved this by keeping the whole state in a single NumPy array like this:

    initial_state = np.zeros((num_layers, 2, batch_size, state_size))

An LSTM layer has two components, the cell state and the hidden state; that's where the "2" comes from. (This article is great: https://arxiv.org/pdf/1506.00019.pdf)
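
Concretely, the array can be read layer by layer. The lines below are only an illustration of the indexing on the NumPy side, not part of the graph:

    # For a given layer, index 0 holds the cell state c and index 1 holds the
    # hidden state h, each of shape (batch_size, state_size).
    c_layer0 = initial_state[0, 0]   # cell state of the first layer
    h_layer0 = initial_state[0, 1]   # hidden state of the first layer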

When building the graph you unpack and create the tuple state like this:

    state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
    l = tf.unpack(state_placeholder, axis=0)  # tf.unpack was renamed tf.unstack in TensorFlow 1.0
    rnn_tuple_state = tuple(
        [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
         for idx in range(num_layers)]
    )

Then you get the new state the usual way

    cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
    cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

    outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)
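
Carrying the state across runs then looks roughly like the sketch below. `train_step`, `targets`, `batches` and `sess` are hypothetical names standing in for your own training op, target placeholder, data pipeline and tf.Session:

    # Minimal sketch of a training loop with the NumPy state array above.
    current_state = np.zeros((num_layers, 2, batch_size, state_size))
    for x, y in batches:  # hypothetical batch iterator
        _, next_state = sess.run(
            [train_step, state],
            feed_dict={series_batch_input: x,
                       targets: y,                       # hypothetical placeholder
                       state_placeholder: current_state})
        # session.run returns a tuple of LSTMStateTuples; pack it back into the
        # (num_layers, 2, batch_size, state_size) array so it can be fed again.
        current_state = np.array(next_state)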

It shouldn't be like this... perhaps they are working on a solution.

A simple way to feed in an RNN state is to feed both components of the state tuple individually.

    # Constructing the graph
    self.state = rnn_cell.zero_state(...)
    self.output, self.next_state = tf.nn.dynamic_rnn(
        rnn_cell,
        self.input,
        initial_state=self.state)

    # Running with initial state
    output, state = sess.run([self.output, self.next_state], feed_dict={
        self.input: input
    })

    # Running with subsequent state:
    output, state = sess.run([self.output, self.next_state], feed_dict={
        self.input: input,
        self.state[0]: state[0],
        self.state[1]: state[1]
    })
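
This works as written for a single LSTMCell, whose state tuple is just (c, h). With a MultiRNNCell and state_is_tuple=True, self.state is instead a tuple of LSTMStateTuple objects, so one option (a sketch under that assumption, reusing the attribute names above) is to build the feed dictionary layer by layer:

    # Hedged sketch: feed a MultiRNNCell state one layer at a time.
    # self.state is the tuple of LSTMStateTuples returned by zero_state(...),
    # and `state` holds the NumPy values returned by the previous run.
    feed = {self.input: input}
    for layer_tensors, layer_values in zip(self.state, state):
        feed[layer_tensors.c] = layer_values.c  # cell state for this layer
        feed[layer_tensors.h] = layer_values.h  # hidden state for this layer
    output, state = sess.run([self.output, self.next_state], feed_dict=feed)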
