簡體   English   中英

當state_is_tuple = True時,如何設置TensorFlow RNN狀態?

[英]How do I set TensorFlow RNN state when state_is_tuple=True?

使用TensorFlow編寫了一個RNN語言模型 該模型實現為RNN類。 圖結構是在構造函數中構建的,而RNN.trainRNN.test方法則運行它。

我想在移動到訓練集中的新文檔時,或者當我想在訓練期間運行驗證集時,能夠重置RNN狀態。 我通過管理訓練循環內的狀態,通過提要字典將其傳遞到圖表中來實現此目的。

在構造函數中,我像這樣定義RNN

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
                                                  initial_state=self.state)

訓練循環看起來像這樣

 for document in document:
     state = session.run(self.reset_state)
     for x, y in document:
          _, state = session.run([self.train_step, self.next_state], 
                                 feed_dict={self.x:x, self.y:y, self.state:state})

xy是文檔中的批量訓練數據。 我的想法是,每次批處理后都會傳遞最新的狀態,除非我啟動一個新文檔,當我通過運行self.reset_state將狀態歸零時。

這一切都有效。 現在我想更改我的RNN以使用推薦的state_is_tuple=True 但是,我不知道如何通過提要字典傳遞更復雜的LSTM狀態對象。 另外我不知道在self.state = tf.placeholder(...)函數中傳遞給self.state = tf.placeholder(...)行的參數。

這里的正確策略是什么? 可用的dynamic_rnn仍然沒有太多示例代碼或文檔。


TensorFlow問題26952838似乎相關。

關於WILDML的博客文章解決了這些問題,但沒有直接說明答案。

另請參見TensorFlow:記住下一批次的LSTM狀態(有狀態LSTM)

Tensorflow占位符的一個問題是你只能用Python列表或Numpy數組(我認為)來提供它。 因此,您無法在LSTMStateTuple的元組中的運行之間保存狀態。

我通過將狀態保存在這樣的張量中來解決這個問題

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

LSTM層中有兩個組件,即單元狀態隱藏狀態 ,這就是“2”的來源。 (這篇文章很棒: https//arxiv.org/pdf/1506.00019.pdf

構建圖形時,解壓縮並創建元組狀態,如下所示:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
         [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
          for idx in range(num_layers)]
)

然后你通常的方式得到新的狀態

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

它應該不是這樣的......也許他們正在努力解決問題。

在RNN狀態下饋送的簡單方法是單獨地饋送狀態元組的兩個分量。

# Constructing the graph
self.state = rnn_cell.zero_state(...)
self.output, self.next_state = tf.nn.dynamic_rnn(
    rnn_cell,
    self.input,
    initial_state=self.state)

# Running with initial state
output, state = sess.run([self.output, self.next_state], feed_dict={
    self.input: input
})

# Running with subsequent state:
output, state = sess.run([self.output, self.next_state], feed_dict={
    self.input: input,
    self.state[0]: state[0],
    self.state[1]: state[1]
})

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM