[英]How to fit list of numpy array into LSTM Neural Network?
我是Keras的新手,非常感謝您的幫助。 對於我的項目,我正在嘗試在多個時間序列上訓練神經網絡。 我通過運行for循環使其適合模型的每個時間序列來使其工作。 代碼如下:
for i in range(len(train)):
history = model.fit(train_X[i], train_Y[i], epochs=20, batch_size=6, verbose=0, shuffle=True)
如果我沒看錯,我正在這里進行在線培訓。 現在,我正在嘗試進行批量培訓,以查看結果是否更好。 我試圖擬合一個包含所有時間序列的列表(每個時間序列都轉換為numpy數組),但是出現此錯誤:
Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 array(s), but instead got the following list of 56 arrays:
以下是有關數據集和模型的信息:
model = Sequential()
model.add(LSTM(1, input_shape=(1,16),return_sequences=True))
model.add(Flatten())
model.add(Dense(1, activation='tanh'))
model.compile(loss='mae', optimizer='adam', metrics=['accuracy'])
model.summary()
Layer (type) Output Shape Param #
=================================================================
lstm_2 (LSTM) (None, 1, 1) 72
_________________________________________________________________
flatten_2 (Flatten) (None, 1) 0
_________________________________________________________________
dense_2 (Dense) (None, 1) 2
=================================================================
Total params: 74
Trainable params: 74
Non-trainable params: 0
print(len(train_X), train_X[0].shape, len(train_Y), train_Y[0].shape)
56 (1, 23, 16) 56 (1, 23, 1)
這是給我錯誤的代碼塊:
pyplot.figure(figsize=(16, 25))
history = model.fit(train_X, train_Y, epochs=1, verbose=1, shuffle=False, batch_size = len(train_X))
LSTM的輸入形狀應為batch_size
, timesteps
, features
。但是,如果您希望可以使用batch_input_shape
則不必在輸入形狀中提及batch_size
。
model = Sequential()
model.add(LSTM(1, input_shape=(23,16),return_sequences=True))
# model.add(Flatten())
model.add(Dense(1, activation='tanh'))
model.compile(loss='mae', optimizer='adam', metrics=['accuracy'])
model.summary()
X = np.random.random((56,1, 23, 16))
y = np.random.random((56,1, 23, 1))
X=np.squeeze(X,axis =1) #as input shape should be (`batch_size`, `timesteps`, `features`)
y = np.squeeze(y,axis =1)
model.fit(X,y,epochs=1, verbose=1, shuffle=False, batch_size = len(X))
我不確定它是否符合您的目的。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.