
Multichannel CNN-LSTM in Keras

I have a multidimensional time-series dataset with N features (dimensions). I'm building a CNN-LSTM model that has N input channels (one per feature). The model should first perform 1D convolutions on each feature, then merge the outputs and feed them to an LSTM layer. However, I'm having problems with the dimensions (I suspect this is the underlying issue): the merged output does not have the dimensions I expected.

I've tried Flatten() on each feature but it returns (?, ?), and Reshape() doesn't seem to do the trick either.

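from keras.layers import Input, Conv1D, MaxPooling1D, Flatten, concatenate, TimeDistributed, LSTM, Dense
from keras.models import Model

# Init the multichannel CNN-LSTM proto.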
def mccnn_lstm(steps=window, feats=features, dim=1, f=filters, k=kernel, p=pool):
    channels, convs = [], []

    # Multichannel CNN layer
    for i in range(feats):
        chan = Input(shape=(steps, dim))
        conv = Conv1D(filters=f, kernel_size=k, activation="tanh")(chan)
        maxpool = MaxPooling1D(pool_size=p, strides=1)(conv) # Has shape (?, 8, 64)
        flat = Flatten()(maxpool) # Returns (?, ?), not (?, 8*64) as expected
        channels.append(chan)
        convs.append(flat)

    merged = concatenate(convs)  # Returns (?, ?), would expect a tensor like (?, 8*64, num of channels)

    # LSTM layer
    lstm = TimeDistributed(merged)  # TimeDistributed wraps a layer, not a tensor; this result is overwritten on the next line
    lstm = LSTM(64)(merged) # This line raises the error
    dense = Dense(1, activation="sigmoid")(lstm)

    return Model(inputs=channels, outputs=dense)


model = mccnn_lstm()

Error message:

ValueError: Input 0 is incompatible with layer lstm_1: expected ndim=3, found ndim=2

I'd expect the merged output from the multichannel layer to have dimensions (?, 8*64, num of channels), or something similar, which would then be the input to the LSTM layer.
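To see where the 2D tensor comes from, here is a shape-only sketch that uses NumPy arrays as stand-ins for the Keras tensors. The `?` batch dimension is fixed to a hypothetical 4 and N to a hypothetical 3; the 8 and 64 come from the pooled shape noted in the code above.

```python
import numpy as np

# Hypothetical sizes: batch=4 and n_feats=3 are assumptions; 8 and 64 match the
# pooled shape (?, 8, 64) from the question.
batch, steps_out, filters, n_feats = 4, 8, 64, 3
pooled = [np.zeros((batch, steps_out, filters)) for _ in range(n_feats)]

# What the question's code does: flatten each branch, then concatenate.
flat = [p.reshape(batch, -1) for p in pooled]       # each (4, 512)
merged_flat = np.concatenate(flat, axis=-1)         # (4, 1536): ndim=2, rejected by an LSTM

# Alternative 1: skip flattening and concatenate on the channel axis.
merged_seq = np.concatenate(pooled, axis=-1)        # (4, 8, 192): ndim=3

# Alternative 2: the layout described above, (batch, 8*64, N).
stacked = np.stack(flat, axis=-1)                   # (4, 512, 3): ndim=3

print(merged_flat.shape, merged_seq.shape, stacked.shape)
```

Both 3D layouts are valid LSTM input. The first keeps the 8 pooled steps as time steps with 64 * N features each; the second treats each of the 512 conv activations as a time step with N features.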

Are you using Keras? You didn't create a Sequential() model, but you don't need one here: a model with several Input layers has to be built with the functional API, Model(inputs=..., outputs=...), which your code already does.

--

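EDIT:

I think Flatten() is not needed in this model; in fact it is what breaks it. An LSTM layer expects 3D input of shape (batch, timesteps, features), but Flatten() collapses each (?, 8, 64) pooled output into a 2D vector of length 8*64, so concatenating the branches can only produce a 2D tensor, hence "expected ndim=3, found ndim=2". Flatten() is typically used to turn a multi-dimensional output, such as that of a Conv2D() layer, into a 1D vector for a Dense() layer. When the next layer is an LSTM, drop the Flatten() layers and concatenate the pooled outputs along the last (channel) axis, which gives (?, 8, 64 * N). The TimeDistributed(merged) line should also go: TimeDistributed wraps a layer, not a tensor, and its result is never used.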
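A minimal sketch of the model with those changes. It is written against tf.keras, which is an assumption (the question appears to use standalone Keras, where the same layers exist under `keras.layers`). The defaults `f=64`, `k=2`, `p=2` are hypothetical, chosen so that with `steps=10` each pooled branch has the (?, 8, 64) shape mentioned in the question.

```python
from tensorflow import keras
from tensorflow.keras import layers


def mccnn_lstm(steps, feats, dim=1, f=64, k=2, p=2):
    """Multichannel CNN-LSTM: one Conv1D branch per feature, merged on the channel axis."""
    channels, convs = [], []
    for _ in range(feats):
        chan = keras.Input(shape=(steps, dim))
        conv = layers.Conv1D(filters=f, kernel_size=k, activation="tanh")(chan)
        pooled = layers.MaxPooling1D(pool_size=p, strides=1)(conv)  # (None, steps', f)
        channels.append(chan)
        convs.append(pooled)  # no Flatten: keep the time axis for the LSTM

    # Concatenate needs at least two tensors, so a single branch is used as-is.
    merged = convs[0] if feats == 1 else layers.Concatenate(axis=-1)(convs)  # (None, steps', f * feats)

    lstm = layers.LSTM(64)(merged)  # receives a 3D tensor, as it requires
    out = layers.Dense(1, activation="sigmoid")(lstm)
    return keras.Model(inputs=channels, outputs=out)
```

With `mccnn_lstm(steps=10, feats=3)`, each branch goes 10 → 9 (convolution) → 8 (pooling), so the LSTM receives (None, 8, 192), and the model takes a list of three `(batch, 10, 1)` arrays.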
