How to connect two LSTM models in Keras
I want to create a model with two LSTM layers in Keras. However, the following code produces an error:
import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dropout, Activation
from keras.callbacks import ModelCheckpoint
from keras.utils import to_categorical
model = Sequential()
model.add(LSTM(5, activation="softmax"))
model.add(LSTM(5, activation="softmax"))
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['categorical_accuracy'])
# These values are to be predicted.
directions = [-2, -1, 0, 1, 2]
# Sample data. We have three time steps, one
# feature per timestep, and one resulting value.
data = [[[[1], [2], [3]], -1],
        [[[3], [2], [1]], 2],
        [[[4], [5], [7]], 1],
        [[[1], [-1], [10]], -2]]
X = []
y_ = []
# Now we take 10000 samples from the data above.
for i in np.random.choice(len(data), 10000):
    X.append(data[i][0])
    y_.append(data[i][1])
X = np.array(X)
y_ = np.array(y_)
# Shift labels -2..2 to class indices 0..4, then one-hot encode them.
y = to_categorical(y_ + 2, num_classes=5)
model.fit(X, y,
          epochs=3,
          validation_data=(X, y))
print(model.summary())
loss, acc = model.evaluate(X, y)
print("Loss: {:.2f}".format(loss))
print("Accuracy: {:.2f}%".format(acc*100))
I get the following error:
ValueError: Input 0 is incompatible with layer lstm_10: expected ndim=3, found ndim=2
Full traceback:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-35-58fa9218c3f3> in <module>
31 model.fit(X, y,
32 epochs=3,
---> 33 validation_data=(X, y))
34 print(model.summary())
35
C:\Anaconda3\lib\site-packages\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
950 sample_weight=sample_weight,
951 class_weight=class_weight,
--> 952 batch_size=batch_size)
953 # Prepare validation data.
954 do_validation = False
C:\Anaconda3\lib\site-packages\keras\engine\training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
675 # to match the value shapes.
676 if not self.inputs:
--> 677 self._set_inputs(x)
678
679 if y is not None:
C:\Anaconda3\lib\site-packages\keras\engine\training.py in _set_inputs(self, inputs, outputs, training)
587 assert len(inputs) == 1
588 inputs = inputs[0]
--> 589 self.build(input_shape=(None,) + inputs.shape[1:])
590 return
591
C:\Anaconda3\lib\site-packages\keras\engine\sequential.py in build(self, input_shape)
219 self.inputs = [x]
220 for layer in self._layers:
--> 221 x = layer(x)
222 self.outputs = [x]
223 self._build_input_shape = input_shape
C:\Anaconda3\lib\site-packages\keras\layers\recurrent.py in __call__(self, inputs, initial_state, constants, **kwargs)
530
531 if initial_state is None and constants is None:
--> 532 return super(RNN, self).__call__(inputs, **kwargs)
533
534 # If any of `initial_state` or `constants` are specified and are Keras
C:\Anaconda3\lib\site-packages\keras\engine\base_layer.py in __call__(self, inputs, **kwargs)
412 # Raise exceptions in case the input is not compatible
413 # with the input_spec specified in the layer constructor.
--> 414 self.assert_input_compatibility(inputs)
415
416 # Collect input shapes to build layer.
C:\Anaconda3\lib\site-packages\keras\engine\base_layer.py in assert_input_compatibility(self, inputs)
309 self.name + ': expected ndim=' +
310 str(spec.ndim) + ', found ndim=' +
--> 311 str(K.ndim(x)))
312 if spec.max_ndim is not None:
313 ndim = K.ndim(x)
ValueError: Input 0 is incompatible with layer lstm_10: expected ndim=3, found ndim=2
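It looks like the output of the first LSTM layer (presumably ndim = 2) does not match the input the second LSTM layer requires (ndim = 3, i.e. batch, timesteps, features).
What confuses me is that adding LSTM layers the way I do seems to work here, for example: https://adventuresinmachinelearning.com/keras-lstm-tutorial/
The model works when I remove the second LSTM layer.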
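To make the shapes concrete, the sketch below rebuilds X and y from the question's toy data using only NumPy. The fixed seed (`default_rng(0)`) and the `np.eye` one-hot encoding are substitutions for the question's unseeded `np.random.choice` and `to_categorical`. It confirms that what the first LSTM receives is 3-D, (batch, timesteps, features):

```python
import numpy as np

# Same toy data as in the question: 3 timesteps, 1 feature per step, 1 label.
data = [[[[1], [2], [3]], -1],
        [[[3], [2], [1]], 2],
        [[[4], [5], [7]], 1],
        [[[1], [-1], [10]], -2]]

rng = np.random.default_rng(0)  # assumption: seeded for reproducibility
idx = rng.choice(len(data), 10000)
X = np.array([data[i][0] for i in idx])
labels = np.array([data[i][1] for i in idx])

# Shifting by +2 maps directions -2..2 to class indices 0..4; indexing the
# identity matrix one-hot encodes them, like to_categorical(..., num_classes=5).
y = np.eye(5)[labels + 2]

print(X.shape, X.ndim)  # (10000, 3, 1) 3 -> (batch, timesteps, features)
print(y.shape)          # (10000, 5)
```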
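By default, an LSTM returns only its final output, the one produced after the last element of the sequence. To chain two LSTMs, the first one has to pass its output for every element of the sequence on to the second one, which you get by setting return_sequences=True. For example: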
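from keras.models import Sequential
from keras.layers import LSTM
model = Sequential()
model.add(LSTM(5, return_sequences=True))  # one output per timestep: (batch, timesteps, 5)
model.add(LSTM(5, activation="softmax"))  # only the last output: (batch, 5)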
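A rough way to see what return_sequences changes is a toy LSTM forward pass in plain NumPy. This is a simplified sketch with random weights, not Keras's actual implementation, and the function name `lstm_forward` is hypothetical. With return_sequences=True it stacks the hidden state from every timestep into a 3-D array; otherwise it keeps only the last one, which is 2-D and therefore cannot feed another LSTM:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_forward(x, units, return_sequences=False, seed=0):
    """Hypothetical toy LSTM with random weights (illustration only).

    x has shape (batch, timesteps, features). Returns (batch, timesteps, units)
    if return_sequences is True, else (batch, units).
    """
    batch, timesteps, features = x.shape
    rng = np.random.default_rng(seed)
    # Weights for the four gates stacked side by side: input, forget, candidate, output.
    W = rng.normal(scale=0.1, size=(features, 4 * units))
    U = rng.normal(scale=0.1, size=(units, 4 * units))
    b = np.zeros(4 * units)

    h = np.zeros((batch, units))  # hidden state
    c = np.zeros((batch, units))  # cell state
    outputs = []
    for t in range(timesteps):
        z = x[:, t, :] @ W + h @ U + b
        i, f, g, o = np.split(z, 4, axis=1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outputs.append(h)
    if return_sequences:
        return np.stack(outputs, axis=1)  # every timestep's output
    return h  # only the output after the last timestep

x = np.ones((4, 3, 1))  # (batch, timesteps, features)
seq = lstm_forward(x, 5, return_sequences=True)
last = lstm_forward(x, 5)
print(seq.shape)   # (4, 3, 5): 3-D, valid input for another LSTM
print(last.shape)  # (4, 5): 2-D, the "expected ndim=3, found ndim=2" case
stacked = lstm_forward(seq, 5)  # chaining works on the 3-D output
print(stacked.shape)  # (4, 5)
```

Because both calls use the same weights here, the last timestep of `seq` equals `last`: return_sequences=True only exposes the intermediate outputs, it does not change the final one.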
For details on how return_sequences works, see the documentation: https://keras.io/layers/recurrent/
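For completeness, here is a variant of the question's script with the fix applied, assuming TensorFlow 2.x and its bundled `tensorflow.keras` (the question imports standalone `keras`). Unlike the answer, it keeps the default tanh activation in both LSTMs and adds a `Dense(5, activation="softmax")` output layer. That head is our own choice, a common way to produce class probabilities, not something the answer prescribes:

```python
import numpy as np
from tensorflow import keras

data = [[[[1], [2], [3]], -1],
        [[[3], [2], [1]], 2],
        [[[4], [5], [7]], 1],
        [[[1], [-1], [10]], -2]]
rng = np.random.default_rng(0)  # assumption: seeded sampling
idx = rng.choice(len(data), 10000)
X = np.array([data[i][0] for i in idx], dtype="float32")
y = keras.utils.to_categorical(
    np.array([data[i][1] for i in idx]) + 2, num_classes=5)

model = keras.Sequential([
    keras.Input(shape=(3, 1)),                    # (timesteps, features)
    keras.layers.LSTM(5, return_sequences=True),  # -> (batch, 3, 5)
    keras.layers.LSTM(5),                         # -> (batch, 5)
    keras.layers.Dense(5, activation="softmax"),  # class probabilities (our choice)
])
model.compile(loss="categorical_crossentropy", optimizer="adam",
              metrics=["categorical_accuracy"])
model.fit(X, y, epochs=3, verbose=0)
loss, acc = model.evaluate(X, y, verbose=0)
print("Loss: {:.2f}".format(loss))
print("Accuracy: {:.2f}%".format(acc * 100))
```

Declaring the input shape up front with `keras.Input` builds the model immediately, so a shape mismatch like the one in the question shows up when the model is defined rather than at `fit` time.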
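Notice: technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original source. For any questions, contact: yoyou2525@163.com.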