
Saving LSTM RNN state between runs in Tensorflow

What is the best way to save LSTM state between runs in Tensorflow? For the prediction phase I need to pass the data in one time step at a time, because the input at the next time step depends on the output of the previous one.

I used the suggestion from this post: Tensorflow, best way to save state in RNNs? and tested it by repeatedly feeding in the same input without running the optimizer. If I understand correctly, the state is being saved if the output changes on every run, but not if it stays the same. The result was that it saved the state the first time, but after that the output stayed the same.

Here is my code:

 import tensorflow as tf

 pieces = data_generator.load_pieces(5)

 batches = 100
 sizes = [126, 122]
 steps = 128
 layers = 2

 x = tf.placeholder(tf.float32, shape=[batches, steps, sizes[0]])
 y_ = tf.placeholder(tf.float32, shape=[batches, steps, sizes[1]])

 W = tf.Variable(tf.random_normal([sizes[0], sizes[1]]))
 b = tf.Variable(tf.random_normal([sizes[1]]))

 layer = tf.nn.rnn_cell.BasicLSTMCell(sizes[0], forget_bias=0.0)
 lstm = tf.nn.rnn_cell.MultiRNNCell([layer] * layers)

 # ~~~~~ code from linked post ~~~~~
 def get_state_variables(batch_size, cell):
     # For each layer, get the initial state and make a variable out of it
     # to enable updating its value.
     state_variables = []
     for state_c, state_h in cell.zero_state(batch_size, tf.float32):
         state_variables.append(tf.nn.rnn_cell.LSTMStateTuple(
             tf.Variable(state_c, trainable=False),
             tf.Variable(state_h, trainable=False)))
     # Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
     return tuple(state_variables)

 states = get_state_variables(batches, lstm)

 outputs, new_states = tf.nn.dynamic_rnn(lstm, x, initial_state=states, dtype=tf.float32)

 def get_state_update_op(state_variables, new_states):
     # Add an operation to update the train states with the last state tensors
     update_ops = []
     for state_variable, new_state in zip(state_variables, new_states):
         # Assign the new state to the state variables on this layer
         update_ops.extend([state_variable[0].assign(new_state[0]),
                            state_variable[1].assign(new_state[1])])
     # Return a tuple in order to combine all update_ops into a single operation.
     # The tuple's actual value should not be used.
     return tf.tuple(update_ops)

 update_op = get_state_update_op(states, new_states)
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

 output = tf.reshape(outputs, [-1, sizes[0]])
 y = tf.nn.sigmoid(tf.matmul(output, W) + b)
 y = tf.reshape(y, [-1, steps, sizes[1]])

 cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), [1, 2]))
 # train_step = tf.train.AdadeltaOptimizer().minimize(cross_entropy)

 sess = tf.InteractiveSession()
 sess.run(tf.global_variables_initializer())
 batch_x, batch_y = data_generator.get_batch(pieces)
 for i in range(500):
     error, _ = sess.run([cross_entropy, update_op], feed_dict={x: batch_x, y_: batch_y})
     print str(i) + ': ' + str(error)

Here is the error over time:

  • 0: 419.861
  • 1: 419.756
  • 2: 419.756
  • 3: 419.756 ...
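
One way to check whether the state variables are actually being carried over is to read them back after each run of update_op. Below is a minimal sketch of such a check, reusing the states, update_op, x and y_ defined above (the printed sums are only meant to be compared across iterations):

 for i in range(3):
     sess.run(update_op, feed_dict={x: batch_x, y_: batch_y})
     state_vals = sess.run(states)  # tuple of LSTMStateTuple(c, h) numpy arrays, one per layer
     # if the state is really being kept, these sums should change from one iteration to the next
     print str(i) + ': ' + str(float(state_vals[0].h.sum()))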

I would recommend this answer, which I tried a few days ago. It works well.

By the way, there is a way to avoid setting state_is_tuple to false:

import numpy as np
import tensorflow as tf

class CustomLSTMCell(tf.contrib.rnn.LSTMCell):
    def __init__(self, *args, **kwargs):
        # kwargs['state_is_tuple'] = False # force the use of a concatenated state.
        returns = super(CustomLSTMCell, self).__init__(
            *args, **kwargs)  # create an lstm cell
        # change the output size to the state size
        self._output_size = np.sum(self._state_size)
        return returns

    def __call__(self, inputs, state):
        output, next_state = super(
            CustomLSTMCell, self).__call__(inputs, state)
        # return two copies of the state, instead of the output and the state
        return tf.reshape(next_state, shape=[1, -1]), next_state
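
For comparison, here is a minimal placeholder-based sketch of carrying LSTM state across sess.run calls. This is a separate, common pattern rather than part of the CustomLSTMCell above, and the names (num_units, batch, the random input) are only illustrative:

import numpy as np
import tensorflow as tf

num_units, batch = 64, 1
inputs = tf.placeholder(tf.float32, [batch, 1, num_units])  # one time step per run
c_in = tf.placeholder(tf.float32, [batch, num_units])
h_in = tf.placeholder(tf.float32, [batch, num_units])

cell = tf.nn.rnn_cell.LSTMCell(num_units)
init_state = tf.nn.rnn_cell.LSTMStateTuple(c_in, h_in)
outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=init_state)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c = np.zeros([batch, num_units], dtype=np.float32)
    h = np.zeros([batch, num_units], dtype=np.float32)
    for step in range(10):
        x_step = np.random.randn(batch, 1, num_units).astype(np.float32)
        out, (c, h) = sess.run([outputs, final_state],
                               feed_dict={inputs: x_step, c_in: c, h_in: h})
        # c and h now hold the updated cell/hidden state and are fed back on the next run

Here the state never lives in the graph: it is round-tripped through Python on every call, which also makes it straightforward to persist between runs.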
