
Getting dimensions right for a single layer keras LSTM

I'm having a hard time getting the dimensions of an LSTM network right.

So I have the following data:

train_data.shape
 (25391, 3) # to be read as 25391 timesteps and 3 features

train_labels.shape
 (25391, 1) # to be read as 25391 timesteps and 1 feature

So I thought my input shape should be (1, len(train_data), train_data.shape[1]), since I plan to submit a single batch. But I get the following error:
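For reference, here is a minimal sketch of the reshape described above. It uses random NumPy arrays as hypothetical stand-ins with the same shapes as `train_data` and `train_labels` (which in the question are DataFrames, so `.values` yields arrays like these). Keras LSTM layers take 3D input of shape (batch, timesteps, features), so treating the whole series as one sample means adding a leading batch axis of size 1:

```python
import numpy as np

# Hypothetical stand-ins with the same shapes as the question's data.
n_timesteps, n_features = 25391, 3
train_x = np.random.rand(n_timesteps, n_features)  # (timesteps, features)
train_y = np.random.rand(n_timesteps, 1)           # (timesteps, 1)

# Add a leading batch axis of size 1: the whole series becomes one sample.
x = train_x.reshape(1, n_timesteps, n_features)
y = train_y.reshape(1, n_timesteps, 1)

# np.expand_dims(train_x, 0) is an equivalent way to add the batch axis.
x_alt = np.expand_dims(train_x, 0)

print(x.shape)  # (1, 25391, 3)
print(y.shape)  # (1, 25391, 1)
```

The input side of this is fine; as the error message says, the mismatch is on the target side.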

Error when checking target: expected lstm_10 to have 2 dimensions, but got array with shape (1, 25391, 1)

Here is the model code:

from keras.models import Sequential
from keras.layers import LSTM

model = Sequential()
model.add(LSTM(1, # predict one feature and one timestep
               batch_input_shape=(1, len(train_data), train_data.shape[1]),
               activation='tanh',
               return_sequences=False))

model.compile(loss = 'categorical_crossentropy', optimizer='adam', metrics = ['accuracy'])
print(model.summary())

# as 1 sample with len(train_data) time steps and train_data.shape[1] features.
model.fit(x=train_data.values.reshape(1, len(train_data), train_data.shape[1]), 
          y=train_labels.values.reshape(1, len(train_labels), train_labels.shape[1]), 
          epochs=1, 
          verbose=1, 
          validation_split=0.8, 
          validation_data=None, 
          shuffle=False)

What should the input dimensions look like?

The problem is the shape of the target (i.e. the labels) you provide, as the message "Error when checking target" indicates. The output of the LSTM layer, which is also the output of the model, has shape (batch_size, 1), because you specify that only the output of the final timestep should be returned (i.e. return_sequences=False). To get an output at each timestep, set return_sequences=True. The output shape of the LSTM layer then becomes (batch_size, num_timesteps, num_units), which is consistent with the shape of the labels array you provide.
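A minimal sketch of the corrected model, assuming TensorFlow 2.x (`tensorflow.keras`). `build_model` is a hypothetical helper introduced here for illustration. The sketch declares the input with an `Input` layer instead of `batch_input_shape`, which leaves the batch size unconstrained (`None`). It also swaps in mean squared error as a placeholder loss; that choice is an assumption, since `categorical_crossentropy` on a single output unit is not meaningful. Use whatever loss fits your labels.

```python
import numpy as np
from tensorflow.keras import Input
from tensorflow.keras.layers import LSTM
from tensorflow.keras.models import Sequential


def build_model(n_timesteps, n_features, n_units=1):
    """Hypothetical helper: single-layer LSTM emitting one output per timestep."""
    model = Sequential([
        # Per-sample shape (timesteps, features); Keras adds the batch axis.
        Input(shape=(n_timesteps, n_features)),
        # return_sequences=True -> output (batch, timesteps, n_units)
        # instead of (batch, n_units) for the last timestep only.
        LSTM(n_units, activation="tanh", return_sequences=True),
    ])
    # Placeholder loss (assumption): MSE for a single real-valued target.
    model.compile(loss="mse", optimizer="adam")
    return model


model = build_model(25391, 3)
print(tuple(model.outputs[0].shape))  # (None, 25391, 1)
```

With this model, inputs reshaped to (1, 25391, 3) and labels reshaped to (1, 25391, 1) match the model's input and output shapes, so `model.fit` should no longer raise the target-shape error.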
