简体   繁体   English

Keras 功能 API 嵌入层 output 到 LSTM

[英]Keras Functional API embedding layer output to LSTM

When passing the output of my embedding layer to the LSTM layer I'm running into a ValueError that I cannot figure out.将嵌入层的 output 传递到 LSTM 层时,我遇到了我无法弄清楚的ValueError My model is:我的 model 是:

def lstm_mod(self, n_cells,batch_size):
        input = tf.keras.Input((self.n_seq, self.n_features))
        embedding = tf.keras.layers.Embedding(batch_size,self.n_seq,input_length=self.n_clusters)(input)
        x= tf.keras.layers.LSTM(n_cells)(embedding) 
        out = tf.keras.layers.Dense(1)(x)
        model = tf.keras.Model(input, out,name="LSTM")
        model.compile(loss='mse', optimizer='Adam')
        return model 

The error is:错误是:

ValueError: Input 0 of layer lstm is incompatible with the layer: expected ndim=3, found ndim=4. Full shape received: [None, 128, 7, 128]

Given that the dimensions passed to the model input and the embedding layer are consistent through the arguments of the model I'm puzzled by this.鉴于传递给 model 输入和嵌入层的尺寸通过 model 的 arguments 是一致的,我对此感到困惑。 Any guidance is appreciated.任何指导表示赞赏。

Keras adds an additional dimension ( None ) when you feed your data through your model because it processes your data in batches.当您通过 model 提供数据时,Keras 会添加一个额外的维度 ( None ),因为它会分批处理您的数据。

In this line:在这一行:

input = tf.keras.Input((self.n_seq, self.n_features))

You've defined a 2-dimensional input, and Keras adds a 3rd dimension (the batch), hence expected ndim=3 .您已经定义了一个二维输入,并且 Keras 添加了第三维(批次),因此expected ndim=3

However, the data that is being passed to the input layer is 4-dimensional, which means that your actual input data shape is 3-dimensional + the batch dimension, not 2-dimensional + batch.但是,传递给输入层的数据是 4 维的,这意味着您的实际输入数据形状是 3 维 + 批次维度,而不是 2 维 + 批次。

To fix this you need to either re-shape your 3-D input to 2-D, or add an additional dimension to the input shape.要解决此问题,您需要将 3-D 输入重新整形为 2-D,或者向输入形状添加额外的维度。

Print out the values for self.n_seq and self.n_features and find out what is missing from the shape 128, 7, 128 and that should guide you as to what you need to add.打印出self.n_seqself.n_features的值,找出形状 128、7、128 中缺少的内容128, 7, 128这将指导您了解需要添加的内容。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM