簡體   English   中英

如何在 tensorflow 2.x 中應用等效的 LSTM?

[英]How to apply equivalent LSTM in tensorflow 2.x?

我使用 tf.contrib 層在 TensorFlow 中編寫循環神經網絡。 我首先制作了 LSTM 單元類型,然后通過將該單元傳遞到另一層來提取 output 和狀態。 但是在 TensorFlow 2.x 中,它似乎可以在一行中完成

output, state_h, state_c = layers.LSTM(self.args.embedding_size, return_state=True, name="encoder")(tf.nn.embedding_lookup(self.embeddings, self.neighborhood_placeholder)

而且我不能像在 tensorflow 1.x 中那樣應用 dropout warpper。 如何將以下代碼轉換為 tensorflow 2.x?

with tf.variable_scope('LSTM'):
            cell = tf.contrib.rnn.DropoutWrapper(
                    tf.contrib.rnn.LayerNormBasicLSTMCell(num_units=self.args.embedding_size, layer_norm=False),
                    input_keep_prob=1.0, output_keep_prob=1.0)
            _, states = tf.nn.dynamic_rnn(
                    cell,
                    tf.nn.embedding_lookup(self.embeddings, self.neighborhood_placeholder),
                    dtype=tf.float32,
                    sequence_length=self.seqlen_placeholder)
            self.lstm_output = states.h

tf.contrib.rnn.DropoutWrapper替換為tf.compat.v1.nn.rnn_cell.DropoutWrapper

tf.contrib.rnn.LayerNormBasicLSTMCell替換為tf.compat.v1.nn.rnn_cell.LSTMCell

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM