According to the documentation, the LSTM layer should handle inputs with shape (None, CONST, CONST). For a variable number of timesteps, it should be able to handle inputs with shape (None, None, CONST).
Let's say my data is the following:
X = [
[
[1, 2, 3],
[4, 5, 6]
],
[
[7, 8, 9]
]
]
Y = [0, 1]
And my model :
model = tf.keras.models.Sequential([
tf.keras.layers.LSTM(32, activation='tanh',input_shape=(None, 3)),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(X, Y)
My question is: how should I format these inputs to make this code work?
I cannot use pandas DataFrames here, as I am used to doing. If I run the code above, I get this error:
Error when checking model input: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see 1 array(s), but instead got the following list of 2 arrays:
And if I replace the last line with:
model.fit(np.array(X), np.array(Y))
The error is now :
Error when checking input: expected lstm_8_input to have 3 dimensions, but got array with shape (2, 1)
You are close, but in Keras/TensorFlow you need to pad your sequences and then use a Masking layer so that the LSTM skips the padded timesteps. Why? Because all entries in a tensor must have the same shape (batch_size, max_length, features), so variable-length sequences have to be padded to a common length.
You can use tf.keras.preprocessing.sequence.pad_sequences (with padding='post', since the default pads at the front) to pad your sequences and obtain something like:
X = np.array([
    [
        [1, 2, 3],
        [4, 5, 6]
    ],
    [
        [7, 8, 9],
        [0, 0, 0],  # padded timestep
    ]
])
X.shape == (2, 2, 3)
Y = np.array([0, 1])
Y.shape == (2,)
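To make the padding step concrete, here is a minimal NumPy-only sketch of what post-padding does. The helper name pad_post is hypothetical and is not part of Keras; it only mimics the effect of pad_sequences with padding='post' and a float dtype, assuming every sequence has the same number of features per timestep.

```python
import numpy as np

def pad_post(sequences, value=0.0):
    """Pad a list of (timesteps, features) sequences at the end with `value`.

    Hypothetical helper for illustration; assumes all sequences share the
    same number of features and that none of them is empty.
    """
    max_len = max(len(seq) for seq in sequences)
    n_features = len(sequences[0][0])
    # Start from a tensor filled with the padding value...
    padded = np.full((len(sequences), max_len, n_features), value, dtype="float32")
    # ...then copy each real sequence into the leading timesteps.
    for i, seq in enumerate(sequences):
        padded[i, :len(seq)] = seq
    return padded

X_raw = [
    [[1, 2, 3], [4, 5, 6]],  # 2 timesteps
    [[7, 8, 9]],             # 1 timestep
]
X = pad_post(X_raw)
Y = np.array([0, 1])
print(X.shape, Y.shape)
```

Padding with 0 only works cleanly with masking if an all-zero feature vector never occurs as real data; otherwise a different padding value would be a safer choice.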
And then use the masking layer:
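model = tf.keras.models.Sequential([
    # Timesteps whose features all equal mask_value are masked, so the LSTM skips them.
    # input_shape goes on the first layer of the model.
    tf.keras.layers.Masking(mask_value=0.0, input_shape=(None, 3)),
    tf.keras.layers.LSTM(32, activation='tanh'),
    tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(loss='binary_crossentropy', optimizer='adam')
model.fit(X, Y)  # X and Y are NumPy arrays, not nested Python lists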
You also want binary_crossentropy as the loss, since you have a binary classification problem with a sigmoid output.
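To check which timesteps would be skipped, the masking rule can be reproduced in plain NumPy. This is only a sketch of the rule (a timestep is kept if at least one of its features differs from the mask value), not the Keras implementation itself; the function name timestep_mask is made up for this example.

```python
import numpy as np

def timestep_mask(batch, mask_value=0.0):
    """Return a (batch, timesteps) boolean array: True where a timestep is kept.

    Hypothetical helper: a timestep is masked out only when ALL of its
    features equal mask_value.
    """
    return np.any(batch != mask_value, axis=-1)

X = np.array([
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [0, 0, 0]],  # second timestep is padding
], dtype="float32")

mask = timestep_mask(X)
lengths = mask.sum(axis=1)  # number of real timesteps per sequence
print(mask)
print(lengths)
```

Here the second sequence keeps only its first timestep, which matches its original length before padding.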