
Tensorflow: Troubles with .clone() in seq2seq model using Attention and BeamSearch

I am trying to implement a seq2seq model in TensorFlow (1.6.0) using bidirectional_dynamic_rnn, attention, and BeamSearchDecoder. (To keep things simple, I have copied only the relevant code.)

# encoder
def make_lstm(rnn_size, keep_prob):
    lstm = tf.nn.rnn_cell.LSTMCell(rnn_size,
        initializer=tf.truncated_normal_initializer(mean=0.0, stddev=1.0))
    lstm_dropout = tf.nn.rnn_cell.DropoutWrapper(lstm, input_keep_prob=keep_prob)
    return lstm_dropout

cell_fw = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
                                       for _ in range(n_layers)])
cell_bw = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
                                       for _ in range(n_layers)])
enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,
                                                        cell_bw,
                                                        rnn_inputs,
                                                        sequence_length=sequence_length,
                                                        dtype=tf.float32)
enc_output = tf.concat(enc_output, 2)


dec_cell = tf.nn.rnn_cell.MultiRNNCell([make_lstm(rnn_size, keep_prob)
                                        for _ in range(num_layers)])
output_layer = Dense(vocab_size,
                     kernel_initializer=tf.truncated_normal_initializer(mean=0.0,
                                                                        stddev=0.1))

# training_decoding_layer
with tf.variable_scope('decode'):
....

# inference_decoding_layer
with tf.variable_scope('decode', reuse=True):
    beam_width = 10
    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(enc_output,
                                                          multiplier=beam_width)
    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(enc_state,
                                                              multiplier=beam_width)
    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(text_length,
                                                          multiplier=beam_width)
    start_tokens = tf.tile(tf.constant([word2ind['<GO>']], dtype=tf.int32),
                           [batch_size], name='start_tokens')
    attn_mech = tf.contrib.seq2seq.BahdanauAttention(num_units=rnn_size,
                                                     memory=tiled_encoder_outputs,
                                                     memory_sequence_length=tiled_sequence_length,
                                                     normalize=True)

    beam_dec_cell = tf.contrib.seq2seq.AttentionWrapper(dec_cell, attn_mech, rnn_size)

    beam_initial_state = beam_dec_cell.zero_state(batch_size=batch_size * beam_width,
                                                  dtype=tf.float32)
    beam_initial_state = beam_initial_state.clone(cell_state=tiled_encoder_final_state)

However, when I try to clone the encoder's final state into the `beam_initial_state` variable in the graph above, I get the following error:

ValueError: The two structures don't have the same number of elements.

First structure (6 elements): AttentionWrapperState(cell_state=
(LSTMStateTuple(c=<tf.Tensor
'decode_1/AttentionWrapperZeroState/checked_cell_state:0' shape=(640,
256) dtype=float32>, h=<tf.Tensor
'decode_1/AttentionWrapperZeroState/checked_cell_state_1:0' shape=(640,
256) dtype=float32>),), attention=<tf.Tensor
'decode_1/AttentionWrapperZeroState/zeros_1:0' shape=(640, 256)
dtype=float32>, time=<tf.Tensor
'decode_1/AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>,
alignments=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_2:0'
shape=(640, ?) dtype=float32>, alignment_history=(), attention_state=
<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_3:0' shape=(640,
?) dtype=float32>)

Second structure (8 elements): AttentionWrapperState(cell_state=
((LSTMStateTuple(c=<tf.Tensor 'decode_1/tile_batch_1/Reshape:0' shape=
(?, 256) dtype=float32>, h=<tf.Tensor
'decode_1/tile_batch_1/Reshape_1:0' shape=(?, 256) dtype=float32>),),
(LSTMStateTuple(c=<tf.Tensor 'decode_1/tile_batch_1/Reshape_2:0' shape=
(?, 256) dtype=float32>, h=<tf.Tensor
'decode_1/tile_batch_1/Reshape_3:0' shape=(?, 256) dtype=float32>),)),
attention=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_1:0'
shape=(640, 256) dtype=float32>, time=<tf.Tensor
'decode_1/AttentionWrapperZeroState/zeros:0' shape=() dtype=int32>,
alignments=<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_2:0'
shape=(640, ?) dtype=float32>, alignment_history=(), attention_state=
<tf.Tensor 'decode_1/AttentionWrapperZeroState/zeros_3:0' shape=(640,
?) dtype=float32>)
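The element counts in the error come from flattening each AttentionWrapperState: the decoder's zero_state carries a one-layer cell state (2 leaves: c and h) plus 4 other tensor fields, giving 6 elements, while the tiled state from `bidirectional_dynamic_rnn` carries both a forward and a backward one-layer state (4 leaves) plus the same 4 fields, giving 8. A minimal plain-Python sketch of that counting (the namedtuple here only mimics TensorFlow's LSTMStateTuple, no TensorFlow required):

```python
from collections import namedtuple

# Stand-in for tf.contrib.rnn.LSTMStateTuple (illustrative only)
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

def count_leaves(structure):
    """Count tensor-like leaves the way nest.flatten would."""
    if isinstance(structure, tuple):
        return sum(count_leaves(item) for item in structure)
    return 1

# cell_state of the decoder's zero_state: one layer -> (c, h) = 2 leaves
zero_cell_state = (LSTMStateTuple(c="c0", h="h0"),)

# cell_state from bidirectional_dynamic_rnn after tile_batch:
# a (forward_layers, backward_layers) pair -> 4 leaves
tiled_cell_state = (
    (LSTMStateTuple(c="c_fw", h="h_fw"),),
    (LSTMStateTuple(c="c_bw", h="h_bw"),),
)

# The 4 remaining AttentionWrapperState fields (attention, time,
# alignments, attention_state) are identical on both sides;
# alignment_history=() contributes no leaves.
print(4 + count_leaves(zero_cell_state))   # 6
print(4 + count_leaves(tiled_cell_state))  # 8
```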

Does anyone have any suggestions? Thanks a lot.

You need to manually add (or concatenate) the forward and backward states for each layer of the MultiRNNCell:

def add_stacked_cell_state(forward_state, backward_state, useGRUCell):
    temp_list = []
    for state_fw, state_bw in zip(forward_state, backward_state):
        if useGRUCell:
            # A GRU layer's state is a single tensor
            temp_list.append(tf.add(state_fw, state_bw))
        else:
            # An LSTM layer's state is an LSTMStateTuple of (c, h)
            temp_list2 = []
            for hidden_fw, hidden_bw in zip(state_fw, state_bw):
                temp_list2.append(tf.add(hidden_fw, hidden_bw))

            LSTMtuple = tf.contrib.rnn.LSTMStateTuple(*temp_list2)
            temp_list.append(LSTMtuple)

    stacked_state = tuple(temp_list)
    return stacked_state

Then apply it:

enc_output, enc_state = tf.nn.bidirectional_dynamic_rnn(cell_fw,
                                                        cell_bw,
                                                        rnn_inputs,
                                                        sequence_length=sequence_length,
                                                        dtype=tf.float32)

enc_state = add_stacked_cell_state(enc_state[0], enc_state[1], useGRUCell=False)

enc_output = tf.add(enc_output[0], enc_output[1])
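The merged state then has the same nesting as a forward-only MultiRNNCell state, so `tile_batch` and `.clone(cell_state=...)` can pair the two structures leaf for leaf. A minimal plain-Python sketch of the per-layer addition (names illustrative, no TensorFlow required):

```python
from collections import namedtuple

# Stand-in for tf.contrib.rnn.LSTMStateTuple (illustrative only)
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])

def merge_bidirectional_state(forward_state, backward_state):
    """Add forward/backward LSTM states layer by layer.

    Element-wise addition keeps the state size at rnn_size, so the
    decoder cell can be built with the same rnn_size as the encoder.
    """
    return tuple(
        LSTMStateTuple(c=fw.c + bw.c, h=fw.h + bw.h)
        for fw, bw in zip(forward_state, backward_state)
    )

fw = (LSTMStateTuple(c=1.0, h=2.0),)  # one-layer forward state
bw = (LSTMStateTuple(c=3.0, h=4.0),)  # one-layer backward state
merged = merge_bidirectional_state(fw, bw)
print(merged)       # (LSTMStateTuple(c=4.0, h=6.0),)
print(len(merged))  # 1 layer, matching the decoder MultiRNNCell
```

Note that this mirrors `tf.add(enc_output[0], enc_output[1])` above: the outputs are summed rather than concatenated. If you concatenate instead, the state and memory sizes double to 2*rnn_size, and the decoder cell and attention mechanism must be sized accordingly.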

