简体   繁体   English

keras - 检查带有嵌入层的目标时出错

[英]keras - Error when checking target with embedding layer

I'm trying to run keras model as follows:我正在尝试按如下方式运行 keras 模型:

model = Sequential()
model.add(Dense(10, activation='relu',input_shape=(286,)))
model.add(Dense(1, activation='softmax',input_shape=(324827, 286)))

This code works, but if I'm trying to add an embedding layer:此代码有效,但如果我尝试添加嵌入层:

model = Sequential()
model.add(Embedding(286,64, input_shape=(286,)))
model.add(Dense(10, activation='relu',input_shape=(286,)))
model.add(Dense(1, activation='softmax',input_shape=(324827, 286)))

I'm getting the following error :我收到以下错误:

ValueError: Error when checking target: expected dense_2 to have 3 dimensions, but got array with shape (324827, 1)

My data have 286 features and 324827 rows.我的数据有 286 个特征和 324827 行。 I'm probably doing something wrong with the shape definitions, can you tell me what it is ?我可能在形状定义上做错了什么,你能告诉我它是什么吗? Thanks谢谢

You don't need to provide the input_shape in the second Dense layer, and neither the first one, only on the first layer, the following layers shape will be coomputed :您不需要在第二个 Dense 层中提供 input_shape,第一个也不需要,仅在第一层上,将计算以下图层形状:

from tensorflow.keras.layers import Embedding, Dense
from tensorflow.keras.models import Sequential

# 286 features and 324827 rows (324827, 286)

model = Sequential()
model.add(Embedding(286,64, input_shape=(286,)))
model.add(Dense(10, activation='relu'))
model.add(Dense(1, activation='softmax'))
model.compile(loss='mse', optimizer='adam')
model.summary()

returns :返回:

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
embedding_2 (Embedding)      (None, 286, 64)           18304     
_________________________________________________________________
dense_2 (Dense)              (None, 286, 10)           650       
_________________________________________________________________
dense_3 (Dense)              (None, 286, 1)            11        
=================================================================
Total params: 18,965
Trainable params: 18,965
Non-trainable params: 0
_________________________________________________________________

I hope it's what you're looking for我希望这是你要找的

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM