簡體   English   中英

如何使用Keras API在TensorFlow Eager中將壓差應用於RNN的輸出?

[英]How to apply dropout to the outputs of an RNN in TensorFlow Eager using the Keras API?

我想將輟學應用於RNN的輸出。 例如,在Tensorflow 1.8.0中,我可以這樣做:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

x = tf.random_uniform((10, 5, 3))

gru_cell1 = tf.contrib.rnn.GRUCell(2)
gru_cell1 = tf.contrib.rnn.DropoutWrapper(gru_cell1, output_keep_prob=0.5)
cell = tf.contrib.rnn.MultiRNNCell([gru_cell1])
init_state = cell.zero_state(10, tf.float32)

cell_output, _ = tf.nn.dynamic_rnn(cell, x,
                                   initial_state=init_state, time_major=False)
cell_output

如何使用Keras API實現同一件事?

我已經想到了以下兩種方法,但它們均未成功:

import tensorflow as tf
import tensorflow.contrib.eager as tfe

tfe.enable_eager_execution()

# Attempt 1
x = tf.random_uniform((10, 5, 3))

gru_layer = tf.keras.layers.GRU(2, return_sequences=True, input_shape=(10, 5, 3))
gru_layer = tf.keras.layers.Dropout(0.5)(gru_layer)

# Gives the following error:
# ValueError: Attempt to convert a value (<tensorflow.python.keras._impl.keras.layers.recurrent.GRU object
#  at 0x000001C520681F60>) with an unsupported type (<class 'tensorflow.python.keras._impl.keras.layers.recurrent.GRU'>) 
# to a Tensor.

# Attempt 2
x = tf.random_uniform((10, 5, 3))

gru_layer = tf.keras.layers.GRU(2, return_sequences=True, input_shape=(10, 5, 3))
gru_layer = tf.keras.layers.TimeDistributed(tf.keras.layers.Dropout(0.4))(gru_layer)

# Gives the following error:
# ValueError: as_list() is not defined on an unknown TensorShape.

要獲得模型輸出而無需培訓,就像在TF代碼中所做的那樣,以下代碼應該可以工作。 實際上,您需要一個Input層,並將每一層都連接到上一層,還需要一個Model

import numpy as np
from keras.models import Model
from keras.layers import Dropout, GRU, Input

x = np.random.randn(10, 5, 3)

inputs = Input(shape=(5, 3))
gru_layer = GRU(2, return_sequences=True)(inputs)
gru_layer = Dropout(0.5)(gru_layer)

model = Model(inputs=inputs, outputs=gru_layer)

output = model.predict(x)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM