繁体   English   中英

Tensorflow,如何访问RNN的所有中间状态,而不仅仅是最后一个状态

[英]Tensorflow, how to access all the middle states of an RNN, not just the last state

我的理解是tf.nn.dynamic_rnn在每个时间步和最终状态返回RNN单元(例如LSTM)的输出。 如何在所有时间步骤中访问单元格状态而不仅仅是最后一个? 例如,我希望能够平均所有隐藏状态,然后在后续层中使用它。

以下是我如何定义LSTM单元格,然后使用tf.nn.dynamic_rnn展开它。 但这只给出了LSTM的最后一个单元状态。

import tensorflow as tf
import numpy as np

# [batch-size, sequence-length, dimensions] 
X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 6]

cell = tf.contrib.rnn.LSTMCell(num_units=64, state_is_tuple=True)

outputs, last_state = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
out, last = sess.run([outputs, last_state], feed_dict=None)

这样的事情应该有效。

import tensorflow as tf
import numpy as np


class CustomRNN(tf.contrib.rnn.LSTMCell):
    def __init__(self, *args, **kwargs):
        kwargs['state_is_tuple'] = False # force the use of a concatenated state.
        returns = super(CustomRNN, self).__init__(*args, **kwargs) # create an lstm cell
        self._output_size = self._state_size # change the output size to the state size
        return returns
    def __call__(self, inputs, state):
        output, next_state = super(CustomRNN, self).__call__(inputs, state)
        return next_state, next_state # return two copies of the state, instead of the output and the state

X = np.random.randn(2, 10, 8)
X[1,6:] = 0
X_lengths = [10, 10]

cell = CustomRNN(num_units=64)

outputs, last_states = tf.nn.dynamic_rnn(
    cell=cell,
    dtype=tf.float64,
    sequence_length=X_lengths,
    inputs=X)

sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())                                 
states, last_state = sess.run([outputs, last_states], feed_dict=None)

这使用连接状态,因为我不知道你是否可以存储任意数量的元组状态。 states变量具有形状(batch_size,max_time_size,state_size)。

我会指出你这个帖子 (我的亮点):

如果每个时间步都需要c和h状态,则可以编写LSTMCell的变体,它将两个状态张量作为输出的一部分返回。 如果你只需要h状态 ,那就是每个时间步输出

作为@jasekp在其评论中写道,输出的是真正的h的状态的一部分。 然后dynamic_rnn方法只会堆叠中的所有h部分跨时(见的字符串DOC _dynamic_rnn_loop此文件中 ):

def _dynamic_rnn_loop(cell,
                      inputs,
                      initial_state,
                      parallel_iterations,
                      swap_memory,
                      sequence_length=None,
                      dtype=None):
  """Internal implementation of Dynamic RNN.
    [...]
    Returns:
    Tuple `(final_outputs, final_state)`.
    final_outputs:
      A `Tensor` of shape `[time, batch_size, cell.output_size]`.  If
      `cell.output_size` is a (possibly nested) tuple of ints or `TensorShape`
      objects, then this returns a (possibly nsted) tuple of Tensors matching
      the corresponding shapes.

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM