[英]ValueError: Shapes (None, 1) and (None, 64) are incompatible Keras
[英]How to fix Keras ValueError: Shapes (None, 3, 2) and (None, 2) are incompatible?
下面的代码给了我错误ValueError: Shapes (None, 3, 2) and (None, 2) are incompatible
。 我想做的是构建一个多任务网络。 我该如何解决? 我正在使用 Tensorflow 2.3.0。
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras import Model
base_model = tf.keras.applications.EfficientNetB7(input_shape=(32,32, 3), weights='imagenet',
include_top=False) # or weights='noisy-student'
for layer in base_model.layers[:]:
layer.trainable = False
x = GlobalAveragePooling2D()(base_model.output)
dropout_rate = 0.3
x = Dense(256, activation='relu')(x)
x = Dropout(dropout_rate)(x)
x = Dense(256, activation='relu')(x)
x = Dropout(dropout_rate)(x)
all_target = []
loss_list = []
test_metrics = {}
for name, node in [("task1", 2), ("task2", 2), ("task3", 2)]:
y1 = Dense(128, activation='relu')(x)
y1 = Dropout(dropout_rate)(y1)
y1 = Dense(64, activation='relu')(y1)
y1 = Dropout(dropout_rate)(y1)
# y1 = Dense(64, activation='relu')(y1)
# y1 = Dropout(dropout_rate)(y1)
y1 = Dense(node, activation='softmax', name=name)(y1)
all_target.append(y1)
loss_list.append('categorical_crossentropy')
test_metrics[name] = "accuracy"
# model = Model(inputs=model_input, outputs=[y1, y2, y3])
model = Model(inputs=base_model.input, outputs=all_target)
model.compile(loss=loss_list, optimizer='adam', metrics=test_metrics)
res=np.random.randint(2, size=3072).reshape(32, 32, 3)
res=np.expand_dims(res, 0)
lab=np.array([[[0,1], [0,1], [0,1]]])
history = model.fit(res, y=lab, epochs=1, verbose=1)
可以想象,错误是由目标的形状引起的。 Keras 期望以下内容:
3 个 NumPy 数组(用于您的三个任务)的列表,形状为 (n_samples, n_categories)
训练将使用此行成功运行:
lab = [np.array([[0, 1]]), np.array([[0, 1]]), np.array([[0, 1]])]
我们有一个不同的版本,但是在运行您的代码时,我遇到了一个提供更多信息的错误:
ValueError:检查模型目标时出错:您传递给模型的 Numpy 数组列表不是模型预期的大小。 期望看到 3 个数组,对于输入 ['task1', 'task2', 'task3'] 但得到以下 1 个数组的列表: [array([[[0, 1], [0, 1] , [0, 1]]]])]...
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.