
How to extract the cell state and hidden state from an RNN model in TensorFlow?

I am new to TensorFlow and have difficulty understanding the RNN module. I am trying to extract the hidden and cell states from an LSTM. My code is based on the implementation from https://github.com/aymericdamien/TensorFlow-Examples.

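import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell  # TensorFlow 0.x module layout

# n_steps, n_input, n_hidden, n_classes, learning_rate, batch_size,
# training_iters and the mnist dataset come from the original example
# (28 steps of 28 inputs per MNIST image)

# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])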

# Define weights
weights = {'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))}
biases = {'out': tf.Variable(tf.random_normal([n_classes]))}

def RNN(x, weights, biases):
    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Permuting batch_size and n_steps
    x = tf.transpose(x, [1, 0, 2])
    # Reshaping to (n_steps*batch_size, n_input)
    x = tf.reshape(x, [-1, n_input])
    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.split(0, n_steps, x)

    # Define a lstm cell with tensorflow
    #with tf.variable_scope('RNN'):
    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True)

    # Get lstm cell output
    outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

    # Linear activation, using rnn inner loop last output
    return tf.matmul(outputs[-1], weights['out']) + biases['out'], states

pred, states = RNN(x, weights, biases)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Initializing the variables
init = tf.initialize_all_variables()
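To see what the transpose/reshape/split at the top of RNN() does to the input, the same three steps can be replayed in NumPy. This is a sketch that assumes np.split(..., axis=0) splits the same way as the old tf.split(0, n_steps, x) call; the sizes are arbitrary toy values, not the MNIST ones.

```python
import numpy as np

# Toy sizes (hypothetical); the MNIST example uses 28 steps of 28 inputs.
batch_size, n_steps, n_input = 2, 3, 4
x = np.arange(batch_size * n_steps * n_input, dtype=np.float32).reshape(
    batch_size, n_steps, n_input)

# Same steps as RNN(): permute to (n_steps, batch_size, n_input) ...
xt = np.transpose(x, (1, 0, 2))
# ... flatten to (n_steps * batch_size, n_input) ...
xt = xt.reshape(-1, n_input)
# ... and cut into n_steps pieces, each of shape (batch_size, n_input).
steps = np.split(xt, n_steps, axis=0)

# Piece t holds the input of every sequence in the batch at time step t.
for t, x_t in enumerate(steps):
    assert x_t.shape == (batch_size, n_input)
    assert np.array_equal(x_t, x[:, t, :])

print(len(steps), steps[0].shape)  # 3 (2, 4)
```

So the list handed to rnn.rnn is time-major: one batch slice per time step.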

Now I want to extract the cell/hidden state for each time step in a prediction. The state is stored in an LSTMStateTuple of the form (c, h), which I can see by running print states. However, calling print states.c.eval() (which, according to the documentation, should give me the values in the tensor states.c) yields an error about an unfed placeholder, even though I am calling it right after predicting something. The code for this is here:

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Print the names of the variables created under the 'RNN' scope
    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='RNN'):
        print(v.name)
    # Keep training until we reach the max iterations
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})

        print(states.c.eval())
        # Calculate batch accuracy
        acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})

        step += 1
    print("Optimization Finished!")

and the error message is

InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float
     [[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

The states are also not visible in tf.all_variables(); only the trained matrix/bias tensors are (as described in the question "Tensorflow: show or save forget gate values in LSTM"). I don't want to build the whole LSTM from scratch, though: I already have the states in the states variable, I just need to evaluate it.

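You can collect the values of the states in the same way accuracy is collected: fetch them with sess.run and feed the placeholders they depend on (states is computed from x, so x must be fed).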

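I think something like

pred_val, states_val, acc = sess.run([pred, states, accuracy], feed_dict={x: batch_x, y: batch_y})

should work. The fetches go into sess.run as a single list, and the results get new names so that the pred and states tensors are not overwritten by NumPy arrays.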

One comment about your assumption: states holds only the values of the hidden state and the memory cell from the last time step.

outputs contains the hidden state for each time step. With rnn.rnn it is a list of seq_len tensors, each of shape [batch_size, hidden_size] (tf.nn.dynamic_rnn instead returns a single [batch_size, seq_len, hidden_size] tensor). So I assume that you want the outputs variable, not states. See the documentation.
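Because rnn.rnn returns outputs as a list with one (batch_size, hidden_size) tensor per time step, sess.run hands back a list of arrays. A minimal NumPy sketch of stacking that list into one time-major array and transposing it to the batch-major [batch_size, seq_len, hidden_size] layout; the values are hypothetical stand-ins, not real LSTM outputs.

```python
import numpy as np

# Stand-in for the evaluated `outputs` list: one (batch_size, hidden_size)
# array per time step. Hypothetical values: every entry of step t equals t.
seq_len, batch_size, hidden_size = 5, 3, 4
outputs = [np.full((batch_size, hidden_size), t, dtype=np.float32)
           for t in range(seq_len)]

# Time-major array: (seq_len, batch_size, hidden_size).
time_major = np.stack(outputs, axis=0)
# Batch-major array: (batch_size, seq_len, hidden_size).
batch_major = np.transpose(time_major, (1, 0, 2))

print(time_major.shape)   # (5, 3, 4)
print(batch_major.shape)  # (3, 5, 4)
# First hidden unit of the first sequence, one value per time step:
print(batch_major[0][:, 0])  # [0. 1. 2. 3. 4.]
```

Either layout holds the same numbers; the transpose only changes which axis is indexed first.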

I have to disagree with the answer of user3480922. For the code:

outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

to be able to extract the hidden state for each time_step in a prediction, you have to use the outputs, because outputs holds the hidden state value for each time_step. However, I am not sure whether there is any way to store the cell state values for each time_step as well, because the states tuple provides the cell state values only for the last time_step.

For example, in the following sample with 5 time_steps (batch size 3, 4 hidden units), outputs[4,:,:] (with time_step = 0,...,4) holds the hidden state values for time_step=4, and they are identical to h in the states tuple, which only has the hidden state values for time_step=4. Likewise, c in the states tuple holds the cell state at time_step=4 only.

  outputs = [[[ 0.0589103 -0.06925126 -0.01531546 0.06108122]
    [ 0.00861215 0.06067181 0.03790079 -0.04296958]
    [ 0.00597713 0.03916606 0.02355802 -0.0277683 ]]

   [[ 0.06252582 -0.07336216 -0.01607122 0.05024602]
    [ 0.05464711 0.03219429 0.06635305 0.00753127]
    [ 0.05385715 0.01259535 0.0524035 0.01696803]]

   [[ 0.0853352 -0.06414541 0.02524283 0.05798233]
    [ 0.10790729 -0.05008117 0.03003334 0.07391824]
    [ 0.10205664 -0.04479517 0.03844892 0.0693808 ]]

   [[ 0.10556188 0.0516542 0.09162509 -0.02726674]
    [ 0.11425048 -0.00211394 0.06025286 0.03575509]
    [ 0.11338984 0.02839304 0.08105748 0.01564003]]

   [[ 0.10072514 0.14767936 0.12387902 -0.07391471]
    [ 0.10510238 0.06321315 0.08100517 -0.00940042]
    [ 0.10553667 0.0984127 0.10094948 -0.02546882]]]

  states = LSTMStateTuple(c=array([[ 0.23870754, 0.24315512, 0.20842518, -0.12798975],
         [ 0.23749796, 0.10797793, 0.14181322, -0.01695861],
         [ 0.2413336 , 0.16692916, 0.17559692, -0.0453596 ]], dtype=float32), h=array([[ 0.10072514, 0.14767936, 0.12387902, -0.07391471],
         [ 0.10510238, 0.06321315, 0.08100517, -0.00940042],
         [ 0.10553667, 0.0984127 , 0.10094948, -0.02546882]], dtype=float32))
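The relationship in this sample (the last entry of outputs equals h, and c is only reported for the final step) follows from how the LSTM recurrence is unrolled. Below is a minimal NumPy sketch, assuming the standard LSTM gate equations with forget_bias added to the forget gate (as in BasicLSTMCell), random untrained weights, and a hypothetical input size of 2. Because the loop over time steps is written out by hand, it can also record c at every step, which rnn.rnn does not return.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, c, h, W, b, forget_bias=1.0):
    """One LSTM step (standard equations); returns the new (c, h)."""
    z = np.concatenate([x_t, h], axis=1) @ W + b
    i, j, f, o = np.split(z, 4, axis=1)  # input, candidate, forget, output
    c_new = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    h_new = np.tanh(c_new) * sigmoid(o)
    return c_new, h_new

# Toy sizes like the sample: 5 steps, batch of 3, 4 hidden units.
n_steps, batch_size, n_input, n_hidden = 5, 3, 2, 4
rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(n_input + n_hidden, 4 * n_hidden))
b = np.zeros(4 * n_hidden)
xs = [rng.normal(size=(batch_size, n_input)) for _ in range(n_steps)]

c = np.zeros((batch_size, n_hidden))
h = np.zeros((batch_size, n_hidden))
outputs, cell_states = [], []
for x_t in xs:
    c, h = lstm_step(x_t, c, h, W, b)
    outputs.append(h)      # what rnn.rnn returns as `outputs`
    cell_states.append(c)  # not returned by rnn.rnn; recorded by hand here

final_c, final_h = c, h  # what rnn.rnn returns as `states`
print(np.allclose(outputs[-1], final_h))      # True
print(np.allclose(cell_states[-1], final_c))  # True
```

Here cell_states is the per-step cell state that the states tuple does not expose; it is available only because the recurrence is unrolled explicitly.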
