
Tensorflow tf.keras.layers.Reshape RNN/LSTM

I have a multivariate dataset that I'm trying to reshape so I can feed it into an LSTM neural network, but I'm struggling to get the Reshape layer to work.

My dataset has the shape (1921535, 6), and every 341 timesteps correspond to one sample. I'd like to reshape it to (23, 341, 6) and feed it to the model. My code is below.

import tensorflow as tf

def df_to_dataset(dataframe, batch_size):
   dataframe = dataframe.copy()
   labels = dataframe.pop('target')
   ds = tf.data.Dataset.from_tensor_slices((dict(dataframe), labels))
   ds = ds.batch(batch_size)
   return ds

max_length = 341
batch_size = 23

train_ds = df_to_dataset(data, batch_size * max_length)

model = tf.keras.Sequential([
   tf.keras.layers.Reshape((batch_size, max_length, 6), input_shape=(batch_size * max_length, 6)),
   tf.keras.layers.LSTM(40, return_sequences=True),
   tf.keras.layers.LSTM(40),
   tf.keras.layers.Dense(1)
   ])

When I run the code, I receive the following error:

InvalidArgumentError: Input to reshape is a tensor with 54901 values, but the requested shape has 369075894 [Op:Reshape]

Thanks in advance

I cannot reproduce your error:

max_length = 341
batch_size = 23

model = tf.keras.Sequential([
   tf.keras.layers.Reshape((batch_size, max_length, 6), input_shape=(batch_size * max_length, 6)),
   tf.keras.layers.LSTM(40, return_sequences=True),
   tf.keras.layers.LSTM(40),
   tf.keras.layers.Dense(1)
   ])

Your code is incorrect: the output of your Reshape layer has 4 dimensions (your 3 dimensions plus the batch dimension), but LSTM expects a 3-dimensional input.
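The numbers in the error message are consistent with this explanation. The short sketch below is plain Python arithmetic, not TensorFlow code; it uses only the sizes given in the question. The reading that each row carried 7 values is an inference from 54901 = 7843 × 7 and is not stated in the question.

```python
# Sizes taken from the question.
max_length = 341   # timesteps per sample
batch_size = 23    # samples the asker wants per batch
n_features = 6     # feature count used in the Reshape target shape

# The dataset is batched with batch_size * max_length rows per batch.
rows_per_batch = batch_size * max_length

# Keras Reshape applies its target shape to every element of the batch,
# so the requested size is (leading dimension) * prod(target shape).
requested = rows_per_batch * (batch_size * max_length * n_features)

# Assumption: the 54901 values received are rows_per_batch rows of equal width.
received = 54901
values_per_row = received // rows_per_batch

print(rows_per_batch)   # 7843
print(requested)        # 369075894
print(values_per_row)   # 7
```

The requested size matches the error message exactly, which suggests the Reshape layer treated the 7843 rows as the batch dimension and tried to build a (23, 341, 6) tensor for each of them.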

I solved the problem by using a Lambda layer with tf.reshape. I don't know why, but I couldn't get the Reshape layer to work.

model = tf.keras.Sequential([
   # -1 lets tf.reshape infer the number of samples from the total size
   tf.keras.layers.Lambda(lambda x: tf.reshape(x, (-1, max_length, 6))),
   tf.keras.layers.LSTM(40, return_sequences=True),
   tf.keras.layers.LSTM(40),
   tf.keras.layers.Dense(1)
   ])
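To illustrate what the -1 does, here is a minimal sketch using NumPy instead of TensorFlow (np.reshape infers a -1 dimension the same way tf.reshape does). The array `flat` is a made-up stand-in for one batch of rows, not the asker's data, and `n_features` is a name introduced here for the 6 in the code above.

```python
import numpy as np

max_length = 341   # timesteps per sample (from the question)
batch_size = 23    # samples per batch (from the question)
n_features = 6     # hypothetical name for the feature count

# Stand-in for one batch of flat rows: 23 * 341 rows, 6 features each.
flat = np.arange(batch_size * max_length * n_features, dtype=np.float32)
flat = flat.reshape(batch_size * max_length, n_features)

# -1 tells reshape to infer the number of samples from the total size.
seq = flat.reshape(-1, max_length, n_features)
print(seq.shape)  # (23, 341, 6)

# Consecutive rows stay together: sample 0 is the first 341 rows of flat.
assert np.array_equal(seq[0], flat[:max_length])

# The full dataset from the question divides evenly into samples and batches.
n_samples = 1921535 // max_length
print(n_samples)                # 5635
print(n_samples // batch_size)  # 245
```

Because -1 replaces the hard-coded 23, the same reshape works for any batch whose row count is a multiple of 341.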

