How to reshape a tensor for the LSTM layer of a recurrent neural network

I am trying to train an RNN. My input X has shape (5018, 481) and the labels y have shape (5018,). I have converted both X and y to tensors as follows:

import tensorflow as tf

x_train_tensor = tf.convert_to_tensor(X, dtype=tf.float32)
y_train_tensor = tf.convert_to_tensor(y, dtype=tf.float32)

Then I build and train the following RNN in Keras:

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Dense(100, activation='elu', input_shape=(481,)),
    keras.layers.LSTM(64, return_sequences=False, dropout=0.1, recurrent_dropout=0.1),
    keras.layers.Dense(25, activation='elu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1, activation='elu')
])

opt = keras.optimizers.Adam(learning_rate=0.001)  # the old `lr` argument name is deprecated
model.compile(optimizer=opt, loss='mean_squared_error', metrics=['mse'])

model.fit(x_train_tensor, y_train_tensor, epochs=8)

I get the following error:

ValueError: Input 0 of layer lstm_1 is incompatible with the layer: 
expected ndim=3, found ndim=2. Full shape received: [None, 100]

Does anyone have a solution?

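The LSTM layer expects a 3-dimensional input of shape (batch, timesteps, features), but it receives the 2-dimensional output of the first Dense layer. You can see this if you print the summary of the network without the LSTM layer: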

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Dense(100, activation='elu', input_shape=(481,)),
    keras.layers.Dense(25, activation='elu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1, activation='elu')
])

opt = keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=opt, loss='mean_squared_error', metrics=['mse'])
model.summary()

You get:

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_19 (Dense)             (None, 100)               48200     
_________________________________________________________________
dense_20 (Dense)             (None, 25)                2525      
_________________________________________________________________
dropout_7 (Dropout)          (None, 25)                0         
_________________________________________________________________
dense_21 (Dense)             (None, 1)                 26        
=================================================================
Total params: 50,751
Trainable params: 50,751
Non-trainable params: 0
_________________________________________________________________
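The parameter counts in this summary follow from each Dense layer holding one weight per (input, unit) pair plus one bias per unit, while Dropout has no parameters. A small pure-Python check, using only the layer sizes from the model above (the helper name `dense_params` is just for illustration):

```python
def dense_params(n_inputs: int, units: int) -> int:
    # One weight per (input, unit) pair plus one bias per unit.
    return n_inputs * units + units

# (inputs, units) for the three Dense layers; Dropout adds no parameters.
layers = [(481, 100), (100, 25), (25, 1)]
counts = [dense_params(n_in, n_out) for n_in, n_out in layers]

print(counts)       # [48200, 2525, 26]
print(sum(counts))  # 50751
```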

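As the summary shows, the first Dense layer outputs a 2-dimensional tensor of shape (None, 100). The error message says the LSTM layer requires 3 dimensions, so add one dimension with a Reshape layer. This turns the 100 Dense outputs into a sequence of 100 timesteps with 1 feature each: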

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Dense(100, activation='elu', input_shape=(481,)),
    keras.layers.Reshape((100, 1)),  # (None, 100) -> (None, 100, 1): 100 timesteps, 1 feature each
    keras.layers.LSTM(64, return_sequences=False, dropout=0.1, recurrent_dropout=0.1),
    keras.layers.Dense(25, activation='elu'),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1, activation='elu')
])

opt = keras.optimizers.Adam(learning_rate=0.001)
model.compile(optimizer=opt, loss='mean_squared_error', metrics=['mse'])
model.summary()

and you'll get:

Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_16 (Dense)             (None, 100)               48200     
_________________________________________________________________
reshape_2 (Reshape)          (None, 100, 1)            0         
_________________________________________________________________
lstm_4 (LSTM)                (None, 64)                16896     
_________________________________________________________________
dense_17 (Dense)             (None, 25)                1625      
_________________________________________________________________
dropout_6 (Dropout)          (None, 25)                0         
_________________________________________________________________
dense_18 (Dense)             (None, 1)                 26        
=================================================================
Total params: 66,747
Trainable params: 66,747
Non-trainable params: 0
_________________________________________________________________
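The LSTM's 16,896 parameters can also be checked by hand. The sketch below assumes Keras's standard LSTM with biases enabled (the default `use_bias=True`). That layer has four gates (input, forget, cell candidate, output), and each gate has an input kernel, a recurrent kernel and a bias. The helper names are hypothetical, for illustration only:

```python
def lstm_params(input_dim: int, units: int) -> int:
    # 4 gates, each with an input kernel (input_dim x units),
    # a recurrent kernel (units x units) and a bias vector (units).
    return 4 * (units * (input_dim + units) + units)

def dense_params(n_inputs: int, units: int) -> int:
    return n_inputs * units + units

# Reshape((100, 1)) feeds the LSTM 1 feature per timestep.
lstm = lstm_params(input_dim=1, units=64)
total = dense_params(481, 100) + lstm + dense_params(64, 25) + dense_params(25, 1)

print(lstm)   # 16896
print(total)  # 66747
```

The sequence length (100) does not appear in the formula. The LSTM reuses the same weights at every timestep, so only the feature dimension (1 here) and the number of units determine its size.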

Hope this helps.
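An alternative is to reshape the input data itself to (samples, timesteps, features) before training, so that the LSTM reads the raw columns as a sequence. This only makes sense if the 481 columns really are an ordered sequence, which the question does not say. The sketch below uses NumPy with random data as a hypothetical stand-in for X:

```python
import numpy as np

# Hypothetical stand-in for the question's X: 5018 samples x 481 features.
X = np.random.rand(5018, 481).astype(np.float32)

# (samples, features) -> (samples, timesteps, features):
# each of the 481 columns becomes one timestep with 1 feature.
X_seq = X.reshape(X.shape[0], X.shape[1], 1)

# Adding a trailing axis with np.expand_dims gives the same array.
X_seq2 = np.expand_dims(X, axis=-1)

print(X.ndim, X.shape)                # 2 (5018, 481)
print(X_seq.ndim, X_seq.shape)        # 3 (5018, 481, 1)
print(np.array_equal(X_seq, X_seq2))  # True
```

If the data is already a tensor, `tf.expand_dims(x_train_tensor, axis=-1)` does the same. With input in this shape, the LSTM could be the first layer, for example `keras.layers.LSTM(64, input_shape=(481, 1))`, and the Dense + Reshape front end would no longer be needed.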
