簡體   English   中英

Keras / Tensorflow輸入到RNN層

[英]Keras/Tensorflow Input to RNN layers

我正在嘗試在Keras中建立RNN。 我不太了解所需的輸入格式。 我可以建立密集網絡,沒問題,但是我認為RNN層期望輸入尺寸x批處理x時間步長? 有人可以驗證嗎?

這是我要更新的代碼:

原始代碼:

def get_generative(G_in, dense_dim=200, out_dim=50, lr=1e-3):
   x = Dense(dense_dim)(G_in)
   x = Activation('tanh')(x)
   G_out = Dense(out_dim, activation='tanh')(x)
   G = Model(G_in, G_out)
   opt = SGD(lr=lr)
   G.compile(loss='binary_crossentropy', optimizer=opt)
   return G, G_out

G_in = Input(shape=[10])
G, G_out = get_generative(G_in)
G.summary()

修改了GRU圖層和一些稍微不同的尺寸:

def get_generative(G_in, dense_dim=10, out_dim=37, lr=1e-3):
   clear_session()
   x = GRU(dense_dim, activation='tanh',return_state=True)(G_in)
   G_out = GRU(out_dim, return_state=True)(x)
   G = Model(G_in, G_out)
   opt = SGD(lr=lr)
   G.compile(loss='binary_crossentropy', optimizer=opt)
   return G, G_out

G_in = Input(shape=(None,3))
G, G_out = get_generative(G_in)
G.summary()

我在此代碼中看到的錯誤是:ValueError:Tensor(“ gru_1 / strided_slice:0”,shape =(3,10),dtype = float32)必須與Tensor(“ strided_slice_1:0”)來自同一圖, shape =(?, 3),dtype = float32)。

如果刪除上面的“ None”,則會得到:ValueError:輸入0與gru_1層不兼容:預期ndim = 3,找到ndim = 2

任何解釋在這里都會有所幫助。

因為創建輸入張量后清除了會話,所以會出現錯誤。 這就是為什么輸入張量與網絡的其余部分來自不同的圖的原因。 要解決此問題,只需省去clear_session()

您的代碼還有另一個問題:第二個GRU層需要一個序列輸入,因此您應該在第一個GRU層中使用return_sequences=True 您可能想要省略參數return_state=True因為這會使該層返回張量(輸出和狀態)的元組,而不僅僅是一個輸出張量。

總結起來,下面的代碼應該做到這一點:

def get_generative(G_in, dense_dim=10, out_dim=37, lr=1e-3):
   x = GRU(dense_dim, activation='tanh', return_sequences=True)(G_in)
   G_out = GRU(out_dim)(x)
   G = Model(G_in, G_out)
   opt = SGD(lr=lr)
   G.compile(loss='binary_crossentropy', optimizer=opt)
   return G, G_out

這里的問題是RNN層需要以下形式的3D張量輸入:[數量樣本​​,時間步長,特征]。

因此,我們可以將上面的代碼修改為:

def get_generative(G_in, dense_dim=10, out_dim=37, lr=1e-3):
   x = GRU(dense_dim, activation='tanh',return_state=True)(G_in)
   G_out = GRU(out_dim, return_state=True)(x)
   G = Model(G_in, G_out)
   opt = SGD(lr=lr)
   G.compile(loss='binary_crossentropy', optimizer=opt)
   return G, G_out

G_in = Input(shape=(1,3))
G, G_out = get_generative(G_in)
G.summary()

因此,我們要說的是,我們希望輸入任意數量的樣本,每個樣本具有1個時間步長和3個特征。

安娜是正確的,clear_session()不應在生成器函數中。

最后,如果您確實想將數據輸入到網絡中,則其形狀也應與我們剛剛討論的形狀匹配。 您可以使用numpy reshape來做到這一點:

X = np.reshape(X, (X.shape[0], 1, X.shape[1]))

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM