Saving LSTM RNN state between runs in Tensorflow

What's the best way to save the LSTM state between runs in Tensorflow? For the prediction phase, I need to pass in data one timestep at a time, because the input of the next timestep relies on the output of the previous timestep.

I used the suggestion from this post: Tensorflow, best way to save state in RNNs? and tested it by passing in the same input over and over again without running the optimizer. If I understand correctly, if the output changes each time then the state is being saved, but if it stays the same then it isn't. The result was that the output changed the first time but then stayed the same.

Here's my code:

 import tensorflow as tf

 pieces = data_generator.load_pieces(5)  # data_generator is the asker's own module

 batches = 100
 sizes = [126, 122]
 steps = 128
 layers = 2

 x = tf.placeholder(tf.float32, shape=[batches, steps, sizes[0]])
 y_ = tf.placeholder(tf.float32, shape=[batches, steps, sizes[1]])

 W = tf.Variable(tf.random_normal([sizes[0], sizes[1]]))
 b = tf.Variable(tf.random_normal([sizes[1]]))

 layer = tf.nn.rnn_cell.BasicLSTMCell(sizes[0], forget_bias=0.0)
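 # Note: [layer] * layers makes both layers share this one cell object;
 # newer TF 1.x releases require a separate BasicLSTMCell per layer.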
 lstm = tf.nn.rnn_cell.MultiRNNCell([layer] * layers)

 # ~~~~~ code from linked post ~~~~~
 def get_state_variables(batch_size, cell):
     # For each layer, get the initial state and make a variable out of it
     # to enable updating its value.
     state_variables = []
     for state_c, state_h in cell.zero_state(batch_size, tf.float32):
         state_variables.append(tf.nn.rnn_cell.LSTMStateTuple(
             tf.Variable(state_c, trainable=False),
             tf.Variable(state_h, trainable=False)))
     # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
     return tuple(state_variables)

 states = get_state_variables(batches, lstm)

 outputs, new_states = tf.nn.dynamic_rnn(lstm, x, initial_state=states, dtype=tf.float32)

 def get_state_update_op(state_variables, new_states):
     # Add an operation to update the train states with the last state tensors
     update_ops = []
     for state_variable, new_state in zip(state_variables, new_states):
         # Assign the new state to the state variables on this layer
         update_ops.extend([state_variable[0].assign(new_state[0]),
                            state_variable[1].assign(new_state[1])])
     # Return a tuple in order to combine all update_ops into a single operation.
     # The tuple's actual value should not be used.
     return tf.tuple(update_ops)

 update_op = get_state_update_op(states, new_states)
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 output = tf.reshape(outputs, [-1, sizes[0]])
 y = tf.nn.sigmoid(tf.matmul(output, W) + b)
 y = tf.reshape(y, [-1, steps, sizes[1]])

 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), [1, 2]))
 # train_step = tf.train.AdadeltaOptimizer().minimize(cross_entropy)

 sess = tf.InteractiveSession()
 sess.run(tf.global_variables_initializer())
 batch_x, batch_y = data_generator.get_batch(pieces)
 for i in range(500):
     error, _ = sess.run([cross_entropy, update_op], feed_dict={x: batch_x, y_: batch_y})
     print str(i) + ': ' + str(error)

Here's the error over time:

  • 0: 419.861
  • 1: 419.756
  • 2: 419.756
  • 3: 419.756 ...
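As an aside, here is a minimal, self-contained sketch (my own illustration, not from the post; the fresh cell, batch size 1, and the zero seed input are assumptions) of how the same get_state_variables / get_state_update_op helpers could drive the single-timestep prediction loop described above:

import numpy as np
import tensorflow as tf

size_in, batch = 126, 1
pred_cell = tf.nn.rnn_cell.MultiRNNCell(
    [tf.nn.rnn_cell.BasicLSTMCell(size_in) for _ in range(2)])

x_step = tf.placeholder(tf.float32, [batch, 1, size_in])  # one timestep
pred_states = get_state_variables(batch, pred_cell)
step_out, step_new = tf.nn.dynamic_rnn(pred_cell, x_step,
                                       initial_state=pred_states)
pred_update = get_state_update_op(pred_states, step_new)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    inp = np.zeros([batch, 1, size_in], np.float32)  # seed input (assumed)
    for t in range(10):
        out, _ = sess.run([step_out, pred_update], feed_dict={x_step: inp})
        # The state variables now hold the post-step state, so the next
        # sess.run continues the sequence; derive the next inp from out here.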

I recommend this answer, which I tried a few days ago. It works well.

By the way, there's a way to avoid setting state_is_tuple to False:

import numpy as np
import tensorflow as tf

class CustomLSTMCell(tf.contrib.rnn.LSTMCell):
    def __init__(self, *args, **kwargs):
        # kwargs['state_is_tuple'] = False # force the use of a concatenated state.
        returns = super(CustomLSTMCell, self).__init__(
            *args, **kwargs)  # create an lstm cell
        # change the output size to the state size
        self._output_size = np.sum(self._state_size)
        return returns

    def __call__(self, inputs, state):
        output, next_state = super(
            CustomLSTMCell, self).__call__(inputs, state)
        # return two copies of the state, instead of the output and the state
        return tf.reshape(next_state, shape=[1, -1]), next_state
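For reference, here is a hypothetical usage sketch of the cell above (my own illustration, not from the answer). It assumes TF 1.x and a batch size of 1, since the tf.reshape(next_state, [1, -1]) in __call__ only matches the advertised output_size when the batch dimension is 1:

num_units, input_size = 64, 10   # illustrative sizes
cell = CustomLSTMCell(num_units)

x = tf.placeholder(tf.float32, [1, None, input_size])
c_in = tf.placeholder(tf.float32, [1, num_units])   # previous cell state
h_in = tf.placeholder(tf.float32, [1, num_units])   # previous hidden state
init = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)

# Because the cell emits its state in place of its output, states_seq holds
# the flattened (c, h) pair at every timestep, not just the final one.
states_seq, final_state = tf.nn.dynamic_rnn(cell, x, initial_state=init)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c = np.zeros([1, num_units], np.float32)
    h = np.zeros([1, num_units], np.float32)
    for _ in range(3):  # one timestep per run, feeding the state back in
        inp = np.zeros([1, 1, input_size], np.float32)
        seq, (c, h) = sess.run([states_seq, final_state],
                               feed_dict={x: inp, c_in: c, h_in: h})

Between sess.run calls the state lives only in the fetched numpy arrays, so nothing extra has to be stored in the graph.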
