
Convert unidirectional LSTM cell to bidirectional LSTM cell in TensorFlow 1.0

I have this legacy code that was implemented in TensorFlow 1.0.1. I want to convert the current LSTM cell to a bidirectional LSTM.

with tf.variable_scope("encoder_scope") as encoder_scope:

    cell = contrib_rnn.LSTMCell(num_units=state_size, state_is_tuple=True)
    cell = DtypeDropoutWrapper(cell=cell, output_keep_prob=tf_keep_probabiltiy, dtype=DTYPE)
    cell = contrib_rnn.MultiRNNCell(cells=[cell] * num_lstm_layers, state_is_tuple=True)

    encoder_cell = cell

    encoder_outputs, last_encoder_state = tf.nn.dynamic_rnn(
        cell=encoder_cell,
        dtype=DTYPE,
        sequence_length=encoder_sequence_lengths,
        inputs=encoder_inputs,
        )

I found some examples out there. https://riptutorial.com/tensorflow/example/17004/creating-a-bidirectional-lstm

But I cannot convert my LSTM cell to a bidirectional LSTM cell by referring to them. What should be put into state_below in my case?

Update: Apart from the above issue, I need to clarify how to convert the following decoder network (dynamic_rnn_decoder) to use a bidirectional LSTM. (The documentation does not give any clue about that.)

with tf.variable_scope("decoder_scope") as decoder_scope:

    decoder_cell = tf.contrib.rnn.LSTMCell(num_units=state_size)
    decoder_cell = DtypeDropoutWrapper(cell=decoder_cell, output_keep_prob=tf_keep_probabiltiy, dtype=DTYPE)
    decoder_cell = contrib_rnn.MultiRNNCell(cells=[decoder_cell] * num_lstm_layers, state_is_tuple=True)   

    # define decoder train network
    decoder_outputs_tr, _ , _ = dynamic_rnn_decoder(
        cell=decoder_cell, # the cell function
        decoder_fn= simple_decoder_fn_train(last_encoder_state, name=None),
        inputs=decoder_inputs,
        sequence_length=decoder_sequence_lengths,
        parallel_iterations=None,
        swap_memory=False,
        time_major=False)

Can anyone please clarify?

You can use bidirectional_dynamic_rnn [1]

# Each direction gets half the layers so the total depth stays comparable
cell_fw = contrib_rnn.LSTMCell(num_units=state_size, state_is_tuple=True)
cell_fw = DtypeDropoutWrapper(cell=cell_fw, output_keep_prob=tf_keep_probabiltiy, dtype=DTYPE)
cell_fw = contrib_rnn.MultiRNNCell(cells=[cell_fw] * int(num_lstm_layers / 2), state_is_tuple=True)

cell_bw = contrib_rnn.LSTMCell(num_units=state_size, state_is_tuple=True)
cell_bw = DtypeDropoutWrapper(cell=cell_bw, output_keep_prob=tf_keep_probabiltiy, dtype=DTYPE)
cell_bw = contrib_rnn.MultiRNNCell(cells=[cell_bw] * int(num_lstm_layers / 2), state_is_tuple=True)

encoder_cell_fw = cell_fw
encoder_cell_bw = cell_bw

(encoder_outputs_fw, encoder_outputs_bw), (output_state_fw, output_state_bw) = tf.nn.bidirectional_dynamic_rnn(
    cell_fw=encoder_cell_fw,
    cell_bw=encoder_cell_bw,
    dtype=DTYPE,
    sequence_length=encoder_sequence_lengths,
    inputs=encoder_inputs,
    )

# The outputs come back as a (forward, backward) pair; concatenate them on the feature axis
encoder_outputs = tf.concat([encoder_outputs_fw, encoder_outputs_bw], axis=-1)

# Each per-layer state is an LSTMStateTuple, so concatenate c and h separately
last_encoder_state = tuple(
    contrib_rnn.LSTMStateTuple(
        c=tf.concat([fw.c, bw.c], axis=-1),
        h=tf.concat([fw.h, bw.h], axis=-1))
    for fw, bw in zip(output_state_fw, output_state_bw))

However, as the TensorFlow docs note, this API is deprecated, and you should consider moving to TensorFlow 2 and using keras.layers.Bidirectional(keras.layers.RNN(cell)).
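For reference, a minimal TF2/Keras sketch of the same encoder idea (sizes here are hypothetical placeholders, not taken from your code):

```python
import tensorflow as tf

state_size = 64  # hypothetical; match your own encoder size

# Bidirectional runs the wrapped layer forward and backward over the input
# and, by default, concatenates the two output sequences on the last axis.
encoder = tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(state_size, return_sequences=True, return_state=True))

inputs = tf.random.normal([2, 10, 32])  # [batch, time, features]
outputs, fw_h, fw_c, bw_h, bw_c = encoder(inputs)

# outputs has shape [batch, time, 2 * state_size]; the final forward/backward
# states can be concatenated to initialize a decoder, as in the TF1 code above
decoder_init_h = tf.concat([fw_h, bw_h], axis=-1)
```
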

Regarding the updated question: you cannot use a bidirectional cell in the decoder, because bidirectionality would mean the decoder already knew what it still has to generate [2]

Anyway, to adapt your decoder to the bidirectional encoder, you can concatenate the encoder states and double the decoder's num_units (or halve the num_units in the encoder) [3]

# num_units is doubled to match the concatenated encoder state;
# the layer count matches the encoder's per-direction stack so the states line up
decoder_cell = tf.contrib.rnn.LSTMCell(num_units=2 * state_size)
decoder_cell = DtypeDropoutWrapper(cell=decoder_cell, output_keep_prob=tf_keep_probabiltiy, dtype=DTYPE)
decoder_cell = contrib_rnn.MultiRNNCell(cells=[decoder_cell] * int(num_lstm_layers / 2), state_is_tuple=True)

# define decoder train network
decoder_outputs_tr, _ , _ = dynamic_rnn_decoder(
    cell=decoder_cell, # the cell function
    decoder_fn= simple_decoder_fn_train(last_encoder_state, name=None),
    inputs=decoder_inputs,
    sequence_length=decoder_sequence_lengths,
    parallel_iterations=None,
    swap_memory=False,
    time_major=False)
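The shape arithmetic behind "double the decoder num_units" can be sketched in plain numpy (sizes are hypothetical):

```python
import numpy as np

batch, state_size = 4, 16  # hypothetical sizes

# Final hidden states of the forward and backward encoder directions
h_fw = np.ones((batch, state_size))
h_bw = np.ones((batch, state_size))

# Concatenating on the feature axis doubles the state dimension,
# so the decoder cell must be built with num_units = 2 * state_size
h_cat = np.concatenate([h_fw, h_bw], axis=-1)
print(h_cat.shape)  # (4, 32)
```
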
