[英]How do I set TensorFlow RNN state when state_is_tuple=True?
我使用TensorFlow編寫了一個RNN語言模型 。 該模型實現為RNN
類。 圖結構是在構造函數中構建的,而RNN.train
和RNN.test
方法則運行它。
我想在移動到訓練集中的新文檔時,或者當我想在訓練期間運行驗證集時,能夠重置RNN狀態。 我通過管理訓練循環內的狀態,通過提要字典將其傳遞到圖表中來實現此目的。
在構造函數中,我像這樣定義RNN
cell = tf.nn.rnn_cell.LSTMCell(hidden_units)
rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)
self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)
self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")
self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,
initial_state=self.state)
訓練循環看起來像這樣
for document in document:
state = session.run(self.reset_state)
for x, y in document:
_, state = session.run([self.train_step, self.next_state],
feed_dict={self.x:x, self.y:y, self.state:state})
x
和y
是文檔中的批量訓練數據。 我的想法是,每次批處理后都會傳遞最新的狀態,除非我啟動一個新文檔,當我通過運行self.reset_state
將狀態歸零時。
這一切都有效。 現在我想更改我的RNN以使用推薦的state_is_tuple=True
。 但是,我不知道如何通過提要字典傳遞更復雜的LSTM狀態對象。 另外我不知道在self.state = tf.placeholder(...)
函數中傳遞給self.state = tf.placeholder(...)
行的參數。
這里的正確策略是什么? 可用的dynamic_rnn
仍然沒有太多示例代碼或文檔。
關於WILDML的博客文章解決了這些問題,但沒有直接說明答案。
Tensorflow占位符的一個問題是你只能用Python列表或Numpy數組(我認為)來提供它。 因此,您無法在LSTMStateTuple的元組中的運行之間保存狀態。
我通過將狀態保存在這樣的張量中來解決這個問題
initial_state = np.zeros((num_layers, 2, batch_size, state_size))
LSTM層中有兩個組件,即單元狀態和隱藏狀態 ,這就是“2”的來源。 (這篇文章很棒: https : //arxiv.org/pdf/1506.00019.pdf )
構建圖形時,解壓縮並創建元組狀態,如下所示:
state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
[tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
for idx in range(num_layers)]
)
然后你通常的方式得到新的狀態
cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)
它應該不是這樣的......也許他們正在努力解決問題。
在RNN狀態下饋送的簡單方法是單獨地饋送狀態元組的兩個分量。
# Constructing the graph
self.state = rnn_cell.zero_state(...)
self.output, self.next_state = tf.nn.dynamic_rnn(
rnn_cell,
self.input,
initial_state=self.state)
# Running with initial state
output, state = sess.run([self.output, self.next_state], feed_dict={
self.input: input
})
# Running with subsequent state:
output, state = sess.run([self.output, self.next_state], feed_dict={
self.input: input,
self.state[0]: state[0],
self.state[1]: state[1]
})
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.