
How to extract the cell state and hidden state from an RNN model in TensorFlow?

I am new to TensorFlow and have difficulty understanding the RNN module. I am trying to extract the hidden/cell states from an LSTM. For my code, I am using the implementation from https://github.com/aymericdamien/TensorFlow-Examples .

import tensorflow as tf
from tensorflow.python.ops import rnn, rnn_cell

# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
y = tf.placeholder("float", [None, n_classes])

# Define weights
weights = {'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))}
biases = {'out': tf.Variable(tf.random_normal([n_classes]))}

def RNN(x, weights, biases):
    # Prepare data shape to match `rnn` function requirements
    # Current data input shape: (batch_size, n_steps, n_input)
    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)

    # Permuting batch_size and n_steps
    x = tf.transpose(x, [1, 0, 2])
    # Reshaping to (n_steps*batch_size, n_input)
    x = tf.reshape(x, [-1, n_input])
    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
    x = tf.split(0, n_steps, x)

    # Define a lstm cell with tensorflow
    #with tf.variable_scope('RNN'):
    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0, state_is_tuple=True)

    # Get lstm cell output
    outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

    # Linear activation, using rnn inner loop last output
    return tf.matmul(outputs[-1], weights['out']) + biases['out'], states

pred, states = RNN(x, weights, biases)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

# Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Initializing the variables
init = tf.initialize_all_variables()

Now I want to extract the cell/hidden state for each time step in a prediction. The state is stored in an LSTMStateTuple of the form (c, h), which I can verify by evaluating print states. However, calling print states.c.eval() (which according to the documentation should give me the values of the tensor states.c) yields an error stating that my variables are not initialized, even though I am calling it right after running a prediction. The code for this is here:

# Launch the graph
with tf.Session() as sess:
    sess.run(init)
    step = 1
    # Keep training until reach max iterations
    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='RNN'):
        print v.name
    while step * batch_size < training_iters:
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_x = batch_x.reshape((batch_size, n_steps, n_input))
        # Run optimization op (backprop)
        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})

        print states.c.eval()
        # Calculate batch accuracy
        acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})

        step += 1
    print "Optimization Finished!"

and the error message is:

InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float
     [[Node: Placeholder = Placeholder[dtype=DT_FLOAT, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

The states are also not visible in tf.all_variables(), only the trained matrix/bias tensors are (as described here: Tensorflow: show or save forget gate values in LSTM). I don't want to build the whole LSTM from scratch, though, since I already have the states in the states variable; I just need to fetch it.

You may simply collect the values of the states in the same way accuracy is collected.

I guess pred_val, states_val, acc = sess.run([pred, states, accuracy], feed_dict={x: batch_x, y: batch_y}) should work perfectly fine. Note that the fetches go in a list, and that the results should not reuse the names of the graph tensors pred and states, or the tensors get overwritten by numpy arrays on the first iteration.
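In the training loop from the question, that run call would replace the failing print states.c.eval() line. A minimal sketch (pred_val and states_val are illustrative names, not from the original code):

# fetch the states together with the other ops instead of calling eval()
pred_val, states_val, acc = sess.run([pred, states, accuracy],
                                     feed_dict={x: batch_x, y: batch_y})
print(states_val.c)  # numpy array: cell state after the last time step
print(states_val.h)  # numpy array: hidden state after the last time step

The placeholder error above comes from evaluating states.c without feeding x; for the same reason, states.c.eval(feed_dict={x: batch_x}) would also work.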

One comment about your assumption: "states" only has the values of the "hidden state" and the "memory cell" from the last time step.

The "outputs" contain the "hidden state" from each time step you want (the size of outputs is [batch_size, seq_len, hidden_size]. So I assume that you want "outputs" variable, not "states". See the documentation . “输出”包含你想要的每个时间步的“隐藏状态”(输出的大小是[batch_size,seq_len,hidden_​​size]。所以我假设你想要“输出”变量,而不是“状态”。参见文档

I have to disagree with the answer of user3480922. For the code:

outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)

to be able to extract the hidden state for each time_step in a prediction, you have to use the outputs, because the outputs have the hidden state value for each time_step. However, I am not sure whether there is any way to store the values of the cell state for each time_step as well, because the states tuple provides the cell state values only for the last time_step (a possible workaround is sketched after the sample below).

For example, in the following sample with 5 time_steps, outputs[t,:,:] for t = 0,...,4 holds the hidden state values of every time step, and outputs[4,:,:] (marked with ** below) is identical to the hidden state values in the states tuple h. The states tuple c only has the cell state values at time_step=4.

  outputs = [[[ 0.0589103  -0.06925126 -0.01531546  0.06108122]
              [ 0.00861215  0.06067181  0.03790079 -0.04296958]
              [ 0.00597713  0.03916606  0.02355802 -0.0277683 ]]

             [[ 0.06252582 -0.07336216 -0.01607122  0.05024602]
              [ 0.05464711  0.03219429  0.06635305  0.00753127]
              [ 0.05385715  0.01259535  0.0524035   0.01696803]]

             [[ 0.0853352  -0.06414541  0.02524283  0.05798233]
              [ 0.10790729 -0.05008117  0.03003334  0.07391824]
              [ 0.10205664 -0.04479517  0.03844892  0.0693808 ]]

             [[ 0.10556188  0.0516542   0.09162509 -0.02726674]
              [ 0.11425048 -0.00211394  0.06025286  0.03575509]
              [ 0.11338984  0.02839304  0.08105748  0.01564003]]

           **[[ 0.10072514  0.14767936  0.12387902 -0.07391471]
              [ 0.10510238  0.06321315  0.08100517 -0.00940042]
              [ 0.10553667  0.0984127   0.10094948 -0.02546882]]**]

  states = LSTMStateTuple(
      c=array([[ 0.23870754,  0.24315512,  0.20842518, -0.12798975],
               [ 0.23749796,  0.10797793,  0.14181322, -0.01695861],
               [ 0.2413336 ,  0.16692916,  0.17559692, -0.0453596 ]], dtype=float32),
      h=array(**[[ 0.10072514,  0.14767936,  0.12387902, -0.07391471],
                 [ 0.10510238,  0.06321315,  0.08100517, -0.00940042],
                 [ 0.10553667,  0.0984127 ,  0.10094948, -0.02546882]]**, dtype=float32))
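That said, one workaround for getting the cell state at every time step is to unroll the cell manually, the way rnn.rnn does internally, and collect the state after each step. A minimal sketch under assumptions: it would replace the rnn.rnn call inside RNN() (both build their variables in the "RNN" scope, so they cannot coexist), and x_steps stands for the list of n_steps tensors produced by tf.split there; all names are illustrative.

# sketch: manual unrolling so that the full LSTMStateTuple (c and h)
# is available after every time step
all_c, all_h = [], []
state = lstm_cell.zero_state(batch_size, tf.float32)
with tf.variable_scope("RNN"):
    for t in range(n_steps):
        if t > 0:
            tf.get_variable_scope().reuse_variables()  # share weights across steps
        output, state = lstm_cell(x_steps[t], state)
        all_c.append(state.c)  # cell state after time step t
        all_h.append(state.h)  # hidden state after time step t (equals output)
# all_c and all_h can now be fetched with sess.run like any other tensors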
