[英]Converting GRU layer from PyTorch to TensorFlow
我正在嘗試將以下 GRU 層從 PyTorch(1.9.1) 轉換為 TensorFlow(2.6.0):
# GRU layer
self.gru = nn.GRU(64, 32, bidirectional=True, num_layers=2, dropout=0.25, batch_first=True)
我不確定我目前的執行,尤其是參數的轉換bidirectional
和num_layers
。 我目前的重建如下:
# GRU Layer
model.add(Bidirectional(GRU(32, return_sequences=True, dropout=0.25, time_major=False)))
model.add(Bidirectional(GRU(32, return_sequences=True, dropout=0.25, time_major=False)))
我錯過了什么嗎? 提前感謝您的幫助!
是的,這兩個模型是相同的,至少從參數數量和輸出形狀的角度來看:在 pytorch 中:
import torch
model = torch.nn.Sequential(torch.nn.GRU(64, 32, bidirectional=True, num_layers=2, dropout=0.25, batch_first=True))
from torchinfo import summary
batch_size = 16
summary(model, input_size=(batch_size, 100, 64))
> ========================================================================================== Layer (type:depth-idx) Output Shape
> Param #
> ========================================================================================== Sequential -- --
> ├─GRU: 1-1 [16, 100, 64]
> 37,632
> Total params: 37,632 Trainable params: 37,632 Non-trainable params: 0
> Total mult-adds (M): 60.21
> ============================================================================= Input size (MB): 0.41 Forward/backward pass size (MB): 0.82 Params
> size (MB): 0.15 Estimated Total Size (MB): 1.38
> =============================================================================
在張量流中:
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Bidirectional, GRU
# GRU Layer
model = Sequential()
model.add(Bidirectional(GRU(32, return_sequences=True, dropout=0.25, time_major=False)))
model.add(Bidirectional(GRU(32, return_sequences=True, dropout=0.25, time_major=False)))
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), loss='mse')
a = model.call(inputs=tf.random.normal(shape=(16, 100, 64)))
model.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
bidirectional_8 (Bidirection (16, 100, 64) 18816
_________________________________________________________________
bidirectional_9 (Bidirection (16, 100, 64) 18816
=================================================================
Total params: 37,632
Trainable params: 37,632
Non-trainable params: 0
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.