
LSTM layer does not accept the output shape of the CNN layers

I am trying to create a CNN + LSTM network, but the LSTM layer does not accept the input shape. Is there anything I can do?

model = Sequential()
model.add(Conv2D(128, (2,2), padding = 'same', input_shape=(30, 216, 1)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))

model.add(Conv2D(256, (2,2), padding = 'same'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))

model.add(LSTM(512, input_shape = (7, 54, 256,)))
model.add(Flatten())
model.add(Dense(7, activation='softmax'))

ValueError: Input 0 of layer lstm_21 is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: [None, 7, 54, 256]

The LSTM layer in Keras expects this format as input:

inputs: A 3D tensor with shape [batch, timesteps, feature].
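To see where the extra dimension comes from, here is a small plain-Python sketch (no Keras needed) that traces the shape through the two Conv2D/MaxPooling2D blocks from the question. The helper name `conv_pool_output_shape` is hypothetical, and it assumes `padding='same'` convolutions keep height and width unchanged and that MaxPooling2D uses its defaults (strides equal to `pool_size`, `'valid'` padding), so each spatial size is floor-divided by 2:

```python
def conv_pool_output_shape(height, width, filters_per_block):
    """Hypothetical helper: track (height, width, channels) through
    Conv2D(padding='same') + MaxPooling2D((2, 2)) blocks."""
    channels = None
    for filters in filters_per_block:
        channels = filters                       # 'same' Conv2D only changes the channel count
        height, width = height // 2, width // 2  # 2x2 max pooling, 'valid' padding floors
    return (height, width, channels)


# Question's model: input (30, 216, 1), blocks with 128 and 256 filters
shape = conv_pool_output_shape(30, 216, [128, 256])
print(shape)  # (7, 54, 256); with the batch axis this is [None, 7, 54, 256], i.e. 4D

# LSTM wants [batch, timesteps, features]; flattening each of the 7 slices gives:
timesteps, features = shape[0], shape[1] * shape[2]
print((timesteps, features))  # (7, 13824)
```

So the LSTM receives one axis too many; collapsing the last two axes of each of the 7 slices is what turns the 4D tensor into the 3D one it expects.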

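For this reason, you can't feed the 4D output of the convolutional layers directly into the LSTM. Instead, wrap a Flatten() layer in a TimeDistributed layer so that each slice along the first non-batch axis is flattened into a single feature vector, like this: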

model.add(TimeDistributed(Flatten()))
model.add(LSTM(8))

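The TimeDistributed wrapper applies the wrapped layer to every temporal slice of its input. Here, TimeDistributed(Flatten()) turns the [None, 7, 54, 256] output into [None, 7, 13824], so the LSTM sees 7 timesteps with 13824 features each. Here's a fully working example (with smaller filter counts than in the question):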

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, LSTM, \
    Dense, Flatten, Dropout, MaxPooling2D, Activation, TimeDistributed
import numpy as np

X = np.random.rand(100, 30, 216, 1)
y = np.random.randint(0, 7, 100)

model = Sequential()
model.add(Conv2D(16, (2,2), padding = 'same', input_shape=(30, 216, 1)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))

model.add(Conv2D(32, (2,2), padding = 'same'))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))
model.add(TimeDistributed(Flatten()))  # (None, 7, 54, 32) -> (None, 7, 1728): 7 timesteps
model.add(LSTM(8))                     # (None, 7, 1728) -> (None, 8)
model.add(Dense(7, activation='softmax'))

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
history = model.fit(X, y)
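To make the TimeDistributed(Flatten()) step concrete, here is a NumPy sketch of what it computes on a tensor shaped like the conv output in the working example. It mirrors the semantics of the two layers, not Keras's actual implementation, and assumes Flatten uses row-major (C-order) flattening; the batch size of 4 is arbitrary:

```python
import numpy as np

# Stand-in for the conv stack's output in the working example: [batch, 7, 54, 32]
x = np.random.rand(4, 7, 54, 32)

# TimeDistributed(Flatten()): apply Flatten to each time slice x[:, t] separately
per_step = [x[:, t].reshape(x.shape[0], -1) for t in range(x.shape[1])]
time_distributed = np.stack(per_step, axis=1)

print(time_distributed.shape)  # (4, 7, 1728)

# For Flatten, this equals a single reshape that keeps the batch and time axes
assert np.array_equal(time_distributed, x.reshape(x.shape[0], x.shape[1], -1))
```

The 7 rows left after pooling become the timesteps, and the 54 × 32 values in each row become that step's feature vector.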
