繁体   English   中英

如何使用Keras API在TensorFlow Eager中将压差应用于RNN的输出?

[英]How to apply dropout to the outputs of an RNN in TensorFlow Eager using the Keras API?

我想将辍学应用于RNN的输出。 例如,在Tensorflow 1.8.0中,我可以这样做:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

x = tf.random_uniform((10, 5, 3))

gru_cell1 = tf.contrib.rnn.GRUCell(2)
gru_cell1 = tf.contrib.rnn.DropoutWrapper(gru_cell1, output_keep_prob=0.5)
cell = tf.contrib.rnn.MultiRNNCell([gru_cell1])
init_state = cell.zero_state(10, tf.float32)

cell_output, _ = tf.nn.dynamic_rnn(cell, x,
                                   initial_state=init_state, time_major=False)
cell_output

如何使用Keras API实现同一件事?

我已经想到了以下两种方法,但它们均未成功:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

# Attempt 1
x = tf.random_uniform((10, 5, 3))

gru_layer = tf.keras.layers.GRU(2, return_sequences=True, input_shape=(10, 5, 3))
gru_layer = tf.keras.layers.Dropout(0.5)(gru_layer)

# Gives the following error:
# ValueError: Attempt to convert a value (<tensorflow.python.keras._impl.keras.layers.recurrent.GRU object
#  at 0x000001C520681F60>) with an unsupported type (<class 'tensorflow.python.keras._impl.keras.layers.recurrent.GRU'>) 
# to a Tensor.

# Attempt 2
x = tf.random_uniform((10, 5, 3))

gru_layer = tf.keras.layers.GRU(2, return_sequences=True, input_shape=(10, 5, 3))
gru_layer = tf.keras.layers.TimeDistributed(tf.keras.layers.Dropout(0.4))(gru_layer)

# Gives the following error:
# ValueError: as_list() is not defined on an unknown TensorShape.

要获得模型输出而无需培训,就像在TF代码中所做的那样,以下代码应该可以工作。 实际上,您需要一个Input层,并将每一层都连接到上一层,还需要一个Model

import numpy as np
from keras.models import Model
from keras.layers import Dropout, GRU, Input

x = np.random.randn(10, 5, 3)

inputs = Input(shape=(5, 3))
gru_layer = GRU(2, return_sequences=True)(inputs)
gru_layer = Dropout(0.5)(gru_layer)

model = Model(inputs=inputs, outputs=gru_layer)

output = model.predict(x)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM