简体   繁体   English

如何在 tensorflow 2.x 中应用等效的 LSTM?

[英]How to apply equivalent LSTM in tensorflow 2.x?

I used tf.contrib layer to write recurrent neural network in TensorFlow.我使用 tf.contrib 层在 TensorFlow 中编写循环神经网络。 I made LSTM cell type first and extract the output and states by passing this cell into another layer.我首先制作了 LSTM 单元类型,然后通过将该单元传递到另一层来提取 output 和状态。 But in TensorFlow 2.x it seems like it can be done in a single line但是在 TensorFlow 2.x 中,它似乎可以在一行中完成

output, state_h, state_c = layers.LSTM(self.args.embedding_size, return_state=True, name="encoder")(tf.nn.embedding_lookup(self.embeddings, self.neighborhood_placeholder)

and I can't apply dropout warpper like in tensorflow 1.x.而且我不能像在 tensorflow 1.x 中那样应用 dropout warpper。 How may I convert the following codes into tensorflow 2.x?如何将以下代码转换为 tensorflow 2.x?

with tf.variable_scope('LSTM'):
            cell = tf.contrib.rnn.DropoutWrapper(
                    tf.contrib.rnn.LayerNormBasicLSTMCell(num_units=self.args.embedding_size, layer_norm=False),
                    input_keep_prob=1.0, output_keep_prob=1.0)
            _, states = tf.nn.dynamic_rnn(
                    cell,
                    tf.nn.embedding_lookup(self.embeddings, self.neighborhood_placeholder),
                    dtype=tf.float32,
                    sequence_length=self.seqlen_placeholder)
            self.lstm_output = states.h

Replace tf.contrib.rnn.DropoutWrapper with tf.compat.v1.nn.rnn_cell.DropoutWrapper .tf.contrib.rnn.DropoutWrapper替换为tf.compat.v1.nn.rnn_cell.DropoutWrapper

Replace tf.contrib.rnn.LayerNormBasicLSTMCell with tf.compat.v1.nn.rnn_cell.LSTMCelltf.contrib.rnn.LayerNormBasicLSTMCell替换为tf.compat.v1.nn.rnn_cell.LSTMCell

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 如何解决 LSTM 问题中的 loss: nan &accuracy: 0.0000e+00? 张量流 2.x - How to solve loss: nan & accuracy: 0.0000e+00 in a LSTM problem? Tensorflow 2.x Keras LSTM (TensorFlow 2.x) 中的动态批量大小 - Dynamic batch size in Keras LSTM (TensorFlow 2.x) 如何有效地将数据输入 TensorFlow 2.x, - How to efficiently feed data into TensorFlow 2.x, 如何从 TensorFlow 1.x 迁移到 TensorFlow 2.x - How to migrate from TensorFlow 1.x to TensorFlow 2.x 我尝试使用功能 API 在 tensorflow 2.x 中创建 model,但出现 LSTM 层不兼容错误 - I tried to create model in tensorflow 2.x using functional API, but got LSTM layers incompatible error 如何将输入数据传递给 Java 中现有的 tensorflow 2.x 模型? - How to pass input data to an existing tensorflow 2.x model in Java? 我们如何将 tensorflow 2.x 模型导入 Java? - How can we import a tensorflow 2.x model to Java? 如何在 tensorflow 2.x 中正确操作 tfds.load() 数据集? - How to manipulate tfds.load() datasets correctly in tensorflow 2.x? 如何预处理 Tensorflow 2.x 中实现的 BERT model 的数据集? - How to preprocess a dataset for BERT model implemented in Tensorflow 2.x? 如何获得符号渐变 [Tensorflow 2.x] - How can I get the symbolic gradient [Tensorflow 2.x]
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM