Tensorflow: Troubles with .clone() in seq2seq model using Attention and BeamSearch
I am trying to implement a seq2seq model in TensorFlow (1.6.0) using bidirectional_dynamic_rnn, Attention, and a BeamSearchDecoder. (For simplicity, I have copied only the relevant code.)
# encoder
def make_lstm(rnn_size, keep_prob):
    lstm = tf.nn.rnn_cell.LSTMCell(
        rnn_size,
        initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0))
    lstm_dropout = tf.nn.rnn_cell.DropoutWrapper(lstm,
                                                 input_keep_prob=keep_prob)
    return lstm_dropout

cell_fw = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
                                       for _ in range(n_layers)])
cell_bw = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
                                       for _ in range(n_layers)])
enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,
                                                        cell_bw,
                                                        rnn_inputs,
                                                        sequence_length=sequence_length,
                                                        dtype=tf.float32)
enc_output = tf.concat(enc_output, 2)

# decoder
dec_cell = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
                                        for _ in range(n_layers)])
output_layer = Dense(vocab_size,
                     kernel_initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
# training_decoding_layer
with tf.variable_scope('decode'):
    ....

# inference_decoding_layer
with tf.variable_scope('decode', reuse=True):
    beam_width = 10
    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(enc_output,
                                                          multiplier=beam_width)
    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(enc_state,
                                                              multiplier=beam_width)
    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(text_length,
                                                          multiplier=beam_width)
    start_tokens = tf.tile(tf.constant([word2ind['<GO>']], dtype=tf.int32),
                           [batch_size], name='start_tokens')
    attn_mech = tf.contrib.seq2seq.BahdanauAttention(num_units=rnn_size,
                                                     memory=tiled_encoder_outputs,
                                                     memory_sequence_length=tiled_sequence_length,
                                                     normalize=True)
    beam_dec_cell = tf.contrib.seq2seq.AttentionWrapper(dec_cell,
                                                        attn_mech,
                                                        rnn_size)
    beam_initial_state = beam_dec_cell.zero_state(batch_size=batch_size * beam_width,
                                                  dtype=tf.float32)
    beam_initial_state = beam_initial_state.clone(cell_state=tiled_encoder_final_state)
However, when I try to clone the encoder's final state into the `beam_initial_state` variable shown in the code above, I get the following error:
ValueError: The two structures don't have the same number of elements.

First structure (6 elements): AttentionWrapperState(cell_state=
(LSTMStateTuple(c=<tf.Tensor 'decode_1/AttentionWrapperZeroState/checked_cell_state:0' shape=(640, 256) dtype=float32>,
h=<tf.Tensor 'decode_1/AttentionWrapperZeroState/checked_cell_state_1:0' shape=(640, 256) dtype=float32>),),
attention=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_1:0' shape=(640, 256) dtype=float32>,
time=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>,
alignments=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_2:0' shape=(640, ?) dtype=float32>,
alignment_history=(),
attention_state=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_3:0' shape=(640, ?) dtype=float32>)

Second structure (8 elements): AttentionWrapperState(cell_state=
((LSTMStateTuple(c=<tf.Tensor 'decode_1/tile_batch_1/Reshape:0' shape=(?, 256) dtype=float32>,
h=<tf.Tensor 'decode_1/tile_batch_1/Reshape_1:0' shape=(?, 256) dtype=float32>),),
(LSTMStateTuple(c=<tf.Tensor 'decode_1/tile_batch_1/Reshape_2:0' shape=(?, 256) dtype=float32>,
h=<tf.Tensor 'decode_1/tile_batch_1/Reshape_3:0' shape=(?, 256) dtype=float32>),)),
attention=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_1:0' shape=(640, 256) dtype=float32>,
time=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>,
alignments=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_2:0' shape=(640, ?) dtype=float32>,
alignment_history=(),
attention_state=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_3:0' shape=(640, ?) dtype=float32>)
Does anyone have any suggestions? Many thanks.
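To unpack the error: `clone()` validates the new `cell_state` by flattening both structures and comparing leaf counts (via TensorFlow's `nest` utilities). The decoder's `zero_state` holds a single layer tuple of one `LSTMStateTuple` (2 tensors), while the tiled bidirectional encoder state holds a forward *and* a backward layer tuple (4 tensors); adding the 4 other leaf-bearing `AttentionWrapperState` fields (`attention`, `time`, `alignments`, `attention_state`) gives the 6 vs. 8 in the traceback. A TensorFlow-free sketch of this counting, where `flatten` is a simplified stand-in for `nest.flatten` and strings stand in for tensors:

```python
from collections import namedtuple

LSTMStateTuple = namedtuple('LSTMStateTuple', ['c', 'h'])

def flatten(structure):
    """Recursively flatten nested tuples into a list of leaves,
    mimicking what tf's nest.flatten does for structure checks."""
    if isinstance(structure, tuple):
        leaves = []
        for item in structure:
            leaves.extend(flatten(item))
        return leaves
    return [structure]

# The decoder's zero_state carries one layer's state...
zero_cell_state = (LSTMStateTuple(c='c0', h='h0'),)
# ...while the tiled bidirectional state carries a (forward, backward)
# pair of layer tuples.
tiled_cell_state = ((LSTMStateTuple(c='c_fw', h='h_fw'),),
                    (LSTMStateTuple(c='c_bw', h='h_bw'),))

print(len(flatten(zero_cell_state)))   # 2 leaves (+ 4 other fields = 6)
print(len(flatten(tiled_cell_state)))  # 4 leaves (+ 4 other fields = 8)
```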
You need to manually add (or concatenate) the forward and backward states for each MultiRNNCell yourself:
def add_stacked_cell_state(forward_state, backward_state, useGRUCell):
    temp_list = []
    for state_fw, state_bw in zip(forward_state, backward_state):
        if useGRUCell:
            # GRU states are plain tensors and can be added directly
            temp_list.append(tf.add(state_fw, state_bw))
        else:
            # LSTM states are (c, h) tuples: add c and h separately,
            # then rewrap the sums in an LSTMStateTuple
            temp_list2 = []
            for hidden_fw, hidden_bw in zip(state_fw, state_bw):
                temp_list2.append(tf.add(hidden_fw, hidden_bw))
            temp_list.append(tf.contrib.rnn.LSTMStateTuple(*temp_list2))
    return tuple(temp_list)
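The layer-by-layer pairing logic can be sanity-checked without TensorFlow (this harness is not part of the original answer) by swapping `tf.add` for a pluggable operator and using scalars as stand-in tensors:

```python
from collections import namedtuple
from operator import add

LSTMStateTuple = namedtuple('LSTMStateTuple', ['c', 'h'])

def add_stacked_cell_state_py(forward_state, backward_state,
                              useGRUCell, add_fn=add):
    # Same pairing as above, with tf.add replaced by add_fn
    combined = []
    for state_fw, state_bw in zip(forward_state, backward_state):
        if useGRUCell:
            combined.append(add_fn(state_fw, state_bw))
        else:
            combined.append(LSTMStateTuple(*(add_fn(f, b)
                                             for f, b in zip(state_fw, state_bw))))
    return tuple(combined)

# Two LSTM layers, with scalars standing in for state tensors
fw = (LSTMStateTuple(c=1.0, h=2.0), LSTMStateTuple(c=3.0, h=4.0))
bw = (LSTMStateTuple(c=10.0, h=20.0), LSTMStateTuple(c=30.0, h=40.0))
merged = add_stacked_cell_state_py(fw, bw, useGRUCell=False)
print(merged)
# (LSTMStateTuple(c=11.0, h=22.0), LSTMStateTuple(c=33.0, h=44.0))
```

The merged state has the zero_state's structure (one tuple per layer), so `clone()` no longer sees mismatched leaf counts.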
Then apply it:
enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,
                                                        cell_bw,
                                                        rnn_inputs,
                                                        sequence_length=sequence_length,
                                                        dtype=tf.float32)
enc_state = add_stacked_cell_state(enc_state[0], enc_state[1], useGRUCell=False)
enc_output = tf.add(enc_output[0], enc_output[1])
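The answer also mentions concatenation as an alternative to addition; that keeps all the information but doubles the state width, so the decoder cells and the attention mechanism would then need to be built with `2 * rnn_size` units. A toy illustration of the shape bookkeeping (a hypothetical helper, with Python list concatenation standing in for `tf.concat` along the feature axis):

```python
from collections import namedtuple

LSTMStateTuple = namedtuple('LSTMStateTuple', ['c', 'h'])

def concat_stacked_cell_state(forward_state, backward_state):
    # List concatenation stands in for tf.concat([fw, bw], axis=-1):
    # each c/h vector grows from rnn_size to 2 * rnn_size
    return tuple(LSTMStateTuple(c=s_fw.c + s_bw.c, h=s_fw.h + s_bw.h)
                 for s_fw, s_bw in zip(forward_state, backward_state))

fw = (LSTMStateTuple(c=[1, 2], h=[3, 4]),)   # one layer, rnn_size = 2
bw = (LSTMStateTuple(c=[5, 6], h=[7, 8]),)
merged = concat_stacked_cell_state(fw, bw)
print(merged)  # each state vector is now width 4 == 2 * rnn_size
```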