[英]tensorflow multilabel classification - Incompatible shapes: [7,5] vs. [7]
I'm trying to create a multi-label classifier and have ran into an issue.我正在尝试创建一个多标签分类器并遇到了问题。 I have 5 classes and am getting stuck when trying to train the network, I am relatively new to machine learning and this is the first multi-label classifier I've built.
我有 5 个类,在尝试训练网络时卡住了,我对机器学习比较陌生,这是我构建的第一个多标签分类器。
My code:我的代码:
```
def createModel(learn, act):
model = models.Sequential()
model.add(layers.Conv2D(32, (9,9), activation=act, input_shape=(512,512,1)))
model.add(layers.AveragePooling2D((2,2)))
model.add(layers.Conv2D(64, (9, 9), activation=act))
model.add(layers.AveragePooling2D((2,2)))
model.add(layers.Conv2D(64, (6, 6), activation=act))
model.add(layers.MaxPooling2D((2,2)))
model.add(layers.Conv2D(96, (6, 6), activation=act))
model.add(layers.MaxPooling2D((2,2)))
model.add(layers.Conv2D(128, (3, 3), activation=act))
model.add(layers.MaxPooling2D((2,2)))
model.add(layers.Conv2D(128, (3, 3), activation=act))
model.add(layers.Flatten())
model.add(layers.Dense(128, activation=act))
model.add(layers.Dense(5, activation='sigmoid'))
model.compile(optimizer=optimizers.Adam(learning_rate=learn), loss='binary_crossentropy', metrics=['accuracy'])
return model
model = createModel(0.005, 'tanh')
History = model.fit(Xtrain, ytrain, epochs=300, validation_data=(Xtest, ytest), verbose=0)
```
I utilise my own split function due to my dataset being quite weirdly formatted, therefore I have to create my own labels with pre-existing data which is then run through a hot encoder.我使用自己的拆分 function 因为我的数据集的格式非常奇怪,因此我必须使用预先存在的数据创建自己的标签,然后通过热编码器运行。 Producing labels like s0:
生成像 s0 这样的标签:
```array([[[1., 0.],
[1., 0.],
[1., 0.],
[0., 1.],
[1., 0.]]```
I'm using 10 pieces of image arrays as a test which is split 70% train 30% tests but when I start to train the network, the following error occurs: >我正在使用 10 张图像 arrays 作为测试,该测试分为 70% 训练 30% 测试,但是当我开始训练网络时,出现以下错误:>
```Incompatible shapes: [7,5] vs. [7]
[[node Equal (defined at <ipython-input-54-eb6611e36e68>:3) ]] [Op:__inference_train_function_4978]```
What does this mean and how can I fix it?这是什么意思,我该如何解决?
I removed the hot encoder, which reverted the labels into the previous form of [0.,0.,0.,1.,0.] which then allowed me to train the network.我移除了热编码器,它将标签恢复为以前的 [0.,0.,0.,1.,0.] 形式,然后我可以训练网络。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.