
Tensorflow RNN-LSTM - reset hidden state

I'm building a stateful LSTM used for language recognition. Being stateful, I can train the network with smaller files, and each new batch is like the next sentence in a conversation. However, for the network to be properly trained, I need to reset the hidden state of the LSTM between some batches.

I'm using a variable to store the hidden state of the LSTM for performance:

    with tf.variable_scope('Hidden_state'):
        hidden_state = tf.get_variable("hidden_state", [self.num_layers, 2, self.batch_size, self.hidden_size],
                                       tf.float32, initializer=tf.constant_initializer(0.0), trainable=False)
        # Arrange it to a tuple of LSTMStateTuple as needed
        l = tf.unstack(hidden_state, axis=0)
        rnn_tuple_state = tuple([tf.contrib.rnn.LSTMStateTuple(l[idx][0], l[idx][1])
                                for idx in range(self.num_layers)])

    # Build the RNN
    with tf.name_scope('LSTM'):
        rnn_output, _ = tf.nn.dynamic_rnn(cell, rnn_inputs, sequence_length=input_seq_lengths,
                                          initial_state=rnn_tuple_state, time_major=True)

Now I'm confused about how to reset the hidden state. I've tried two solutions, but neither works:

First solution

Reset the "hidden_state" variable with:

rnn_state_zero_op = hidden_state.assign(tf.zeros_like(hidden_state))

It doesn't work, and I think it's because the unstack and tuple construction are not "re-played" into the graph after running the rnn_state_zero_op operation.

Second solution

Following LSTMStateTuple vs cell.zero_state() for RNN in Tensorflow, I tried to reset the cell state with:

rnn_state_zero_op = cell.zero_state(self.batch_size, tf.float32)

It doesn't seem to work either.

Question

I have another solution in mind, but it's guesswork at best: I'm not keeping the state returned by tf.nn.dynamic_rnn. I've thought about keeping it, but I get a tuple and I can't find a way to build an op to reset the tuple.

At this point I have to admit that I don't quite understand the internal workings of tensorflow, or whether what I'm trying to do is even possible. Is there a proper way to do it?

Thanks!

Thanks to this answer to another question, I was able to find a way to have complete control over whether (and when) the internal state of the RNN should be reset to 0.

First you need to define some variables to store the state of the RNN; this way you will have control over it:

with tf.variable_scope('Hidden_state'):
    state_variables = []
    for state_c, state_h in cell.zero_state(self.batch_size, tf.float32):
        state_variables.append(tf.nn.rnn_cell.LSTMStateTuple(
            tf.Variable(state_c, trainable=False),
            tf.Variable(state_h, trainable=False)))
    # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
    rnn_tuple_state = tuple(state_variables)

Note that this version defines the variables used by the LSTM directly. This is much better than the version in my question, because you don't have to unstack and rebuild the tuple, which adds ops to the graph that you cannot run explicitly.

Secondly, build the RNN and retrieve the final state:

# Build the RNN
with tf.name_scope('LSTM'):
    rnn_output, new_states = tf.nn.dynamic_rnn(cell, rnn_inputs,
                                               sequence_length=input_seq_lengths,
                                               initial_state=rnn_tuple_state,
                                               time_major=True)

So now you have the new internal state of the RNN. You can define two ops to manage it.

The first one will update the variables for the next batch, so that in the next batch the "initial_state" of the RNN is fed with the final state of the previous batch:

# Define an op to keep the hidden state between batches
update_ops = []
for state_variable, new_state in zip(rnn_tuple_state, new_states):
    # Assign the new state to the state variables on this layer
    update_ops.extend([state_variable[0].assign(new_state[0]),
                       state_variable[1].assign(new_state[1])])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
rnn_keep_state_op = tf.tuple(update_ops)

You should add this op to your session run anytime you want to run a batch and keep the internal state.

Beware: if you run batch 1 with this op called, then batch 2 will start from batch 1's final state; but if you don't call it again when running batch 2, then batch 3 will also start from batch 1's final state. My advice is to add this op every time you run the RNN.
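To make the pitfall above concrete, here is a toy plain-Python model of that behaviour (not actual TensorFlow; the class and names are made up for illustration): the "graph" holds a persistent state variable, and the keep-op is what copies a batch's final state into it.

```python
# Toy analogue of running the graph with / without rnn_keep_state_op.
class ToyStatefulRNN:
    def __init__(self):
        self.state_var = 0  # the persisted state variable (starts at zero)

    def run_batch(self, batch_id, keep_state=True):
        # Stand-in for the RNN's final state after processing this batch
        final_state = self.state_var + batch_id
        if keep_state:  # analogous to also running rnn_keep_state_op
            self.state_var = final_state
        return final_state

rnn = ToyStatefulRNN()
rnn.run_batch(1)                    # keep-op run: state variable now holds batch 1's final state
rnn.run_batch(2, keep_state=False)  # keep-op skipped: the variable is left untouched
assert rnn.state_var == 1           # so batch 3 would start from batch 1's final state
```

This is exactly why the op should be run on every batch whose state you want carried forward.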

The second op will be used to reset the internal state of the RNN to zeros:

# Define an op to reset the hidden state to zeros
update_ops = []
for state_variable in rnn_tuple_state:
    # Assign the new state to the state variables on this layer
    update_ops.extend([state_variable[0].assign(tf.zeros_like(state_variable[0])),
                       state_variable[1].assign(tf.zeros_like(state_variable[1]))])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
rnn_state_zero_op = tf.tuple(update_ops)

You can call this op whenever you want to reset the internal state.
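The semantics of the two ops can be sketched in plain NumPy (this is only an analogue of the TensorFlow assigns above; the array shape mirrors the [num_layers, 2, batch_size, hidden_size] variable from the question):

```python
import numpy as np

# `state` plays the role of the LSTM state variables
num_layers, batch_size, hidden_size = 2, 3, 4
state = np.zeros((num_layers, 2, batch_size, hidden_size))

def keep_state(new_state):
    # Analogue of rnn_keep_state_op: copy the batch's final state
    # into the persistent variable, in place
    state[...] = new_state

def reset_state():
    # Analogue of rnn_state_zero_op: zero the persistent variable in place
    state[...] = 0.0

keep_state(np.ones_like(state))
assert state.any()       # state carried over to the next batch
reset_state()
assert not state.any()   # state back to zeros, e.g. at a document boundary
```

In the real graph, the in-place writes are the `assign` ops grouped by `tf.tuple`, and "calling" them means including `rnn_keep_state_op` or `rnn_state_zero_op` in `session.run()`.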

Simplified version of AMairesse's answer above, for a single LSTM layer:

zero_state = tf.zeros(shape=[1, units[-1]])
self.c_state = tf.Variable(zero_state, trainable=False)
self.h_state = tf.Variable(zero_state, trainable=False)
self.init_encoder = tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state)

self.output_encoder, self.state_encoder = tf.nn.dynamic_rnn(cell_encoder, layer, initial_state=self.init_encoder)

# save or reset states
self.update_ops += [self.c_state.assign(self.state_encoder.c, use_locking=True)]
self.update_ops += [self.h_state.assign(self.state_encoder.h, use_locking=True)]

or you can use a replacement for init_encoder to reset the states at step == 0 (you need to pass self.step_tf into session.run() as a placeholder):

self.step_tf = tf.placeholder_with_default(tf.constant(-1, dtype=tf.int64), shape=[], name="step")

self.init_encoder = tf.cond(tf.equal(self.step_tf, 0),
  true_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(zero_state, zero_state),
  false_fn=lambda: tf.nn.rnn_cell.LSTMStateTuple(self.c_state, self.h_state))
