
tf.keras.layers.RNN vs tf.keras.layers.StackedRNNCells: Tensorflow 2

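I am trying to implement a multi-layer RNN model in Tensorflow 2.0. Passing a list of cells directly to tf.keras.layers.RNN and wrapping the cells in tf.keras.layers.StackedRNNCells first give identical results. Can anyone help me understand the difference between tf.keras.layers.RNN and tf.keras.layers.StackedRNNCells?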

import tensorflow as tf

# driving parameters
sz_batch = 128
sz_latent = 200
sz_sequence = 196
sz_feature = 2
n_units = 120
n_layers = 3

Multi-layer RNN with tf.keras.layers.RNN:

inputs = tf.keras.layers.Input(batch_shape=(sz_batch, sz_sequence, sz_feature))
cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
outputs = tf.keras.layers.RNN(cells, stateful=True, return_sequences=True, return_state=False)(inputs)
outputs = tf.keras.layers.Dense(1)(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

Returns:

Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_88 (InputLayer)        [(128, 196, 2)]           0         
_________________________________________________________________
rnn_61 (RNN)                 (128, 196, 120)           218880    
_________________________________________________________________
dense_19 (Dense)             (128, 196, 1)             121       
=================================================================
Total params: 219,001
Trainable params: 219,001
Non-trainable params: 0

Multi-layer RNN with tf.keras.layers.RNN and tf.keras.layers.StackedRNNCells:

inputs = tf.keras.layers.Input(batch_shape=(sz_batch, sz_sequence, sz_feature))
cells = [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]
outputs = tf.keras.layers.RNN(tf.keras.layers.StackedRNNCells(cells),
                              stateful=True, 
                              return_sequences=True, 
                              return_state=False)(inputs)
outputs = tf.keras.layers.Dense(1)(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

Returns:

Model: "model_14"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_89 (InputLayer)        [(128, 196, 2)]           0         
_________________________________________________________________
rnn_62 (RNN)                 (128, 196, 120)           218880    
_________________________________________________________________
dense_20 (Dense)             (128, 196, 1)             121       
=================================================================
Total params: 219,001
Trainable params: 219,001
Non-trainable params: 0
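
As a sanity check on the 218,880 parameters reported for the RNN layer in both summaries, the count can be reproduced by hand. The sketch below assumes the TF 2 default `reset_after=True` for `GRUCell`, which gives each cell two bias vectors. `gru_cell_params` is a helper name introduced here for illustration only, not part of TensorFlow.

```python
def gru_cell_params(input_dim, units):
    """Trainable parameters of one GRU cell, assuming reset_after=True."""
    kernel = input_dim * 3 * units          # input weights for the 3 gates
    recurrent_kernel = units * 3 * units    # recurrent weights for the 3 gates
    bias = 2 * 3 * units                    # separate input and recurrent biases
    return kernel + recurrent_kernel + bias


sz_feature, n_units, n_layers = 2, 120, 3

# The first cell sees the raw features; every later cell sees the previous cell's output.
input_dims = [sz_feature] + [n_units] * (n_layers - 1)
per_layer = [gru_cell_params(d, n_units) for d in input_dims]
total = sum(per_layer)

print(per_layer, total)
```

Since both models stack the same three cells, the same arithmetic applies to each of them.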

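tf.keras.layers.RNN wraps the cells in tf.keras.layers.StackedRNNCells itself if you give it a list or tuple of cells, so the two constructions build the same layer. This is done in https://github.com/tensorflow/tensorflow/blob/v2.1.0/tensorflow/python/keras/layers/recurrent.py#L390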
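
One way to confirm this without reading the source is to inspect the `cell` attribute of the resulting layer. The following is a minimal sketch rather than a definitive test: the sizes are small illustrative values, not the ones from the question, and it assumes a TensorFlow 2 installation in which `RNN` exposes the wrapped cell as `layer.cell`.

```python
import numpy as np
import tensorflow as tf

n_units, n_layers = 4, 3  # small illustrative sizes (assumption, not from the question)

# Variant 1: pass a plain list of cells to RNN.
rnn_list = tf.keras.layers.RNN(
    [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)],
    return_sequences=True)

# Variant 2: wrap the cells in StackedRNNCells explicitly.
rnn_stacked = tf.keras.layers.RNN(
    tf.keras.layers.StackedRNNCells(
        [tf.keras.layers.GRUCell(n_units) for _ in range(n_layers)]),
    return_sequences=True)

# If RNN wraps the list itself, its cell is already a StackedRNNCells instance.
wrapped_automatically = isinstance(rnn_list.cell, tf.keras.layers.StackedRNNCells)
print(wrapped_automatically)

# Call both layers once so their weights are created, then compare sizes.
x = np.zeros((2, 5, 2), dtype="float32")  # (batch, time steps, features)
out_list = rnn_list(x)
out_stacked = rnn_stacked(x)
same_param_count = rnn_list.count_params() == rnn_stacked.count_params()
print(same_param_count)
```

The two layers are initialized independently, so their outputs on non-zero input will generally differ numerically. The point of the check is only that they have the same structure.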

