ValueError: Shapes (None, 1) and (None, 101) are incompatible
The output dense layer has 101 units, but this error still pops up. The error could come from train_dataset and test_dataset, or from how the model is compiled. Help me out with this!
Also, if I want to sample the dataset and then assign the sample to training and testing, how do I do it? I'm new to TensorFlow and Python, and the syntax is confusing.
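import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory

BATCH_SIZE = 32
IMG_SIZE = (224, 224)
directory = "/Food/Food 101/images"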
train_dataset = image_dataset_from_directory(directory,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             label_mode="int",
                                             image_size=IMG_SIZE,
                                             validation_split=0.2,
                                             subset='training',
                                             seed=42)
validation_dataset = image_dataset_from_directory(directory,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  label_mode="int",
                                                  image_size=IMG_SIZE,
                                                  validation_split=0.2,
                                                  subset='validation',
                                                  seed=42)
class_names = train_dataset.class_names
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
def data_augmenter():
    data_augmentation = tf.keras.Sequential()
    data_augmentation.add(tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal"))
    data_augmentation.add(tf.keras.layers.experimental.preprocessing.RandomRotation(0.2))
    return data_augmentation
data_augmentation = data_augmenter()
for image, _ in train_dataset.take(1):
    plt.figure(figsize=(10, 10))
    first_image = image[0]
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
        plt.imshow(augmented_image[0] / 255)
        plt.axis('off')
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
IMG_SHAPE = IMG_SIZE + (3,)
IMG_SHAPE
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=True,
                                               weights='imagenet')
base_model.summary()
nb_layers = len(base_model.layers)
print(base_model.layers[nb_layers - 2].name)
print(base_model.layers[nb_layers - 1].name)
image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
label_batch
base_model.trainable = False
image_var = tf.Variable(image_batch)
pred = base_model(image_var)
tf.keras.applications.mobilenet_v2.decode_predictions(pred.numpy(), top=2)
image_shape=IMG_SIZE
def FC_model(image_shape=IMG_SIZE, data_augmentation=data_augmenter()):
    input_shape = image_shape + (3,)
    base_model = tf.keras.applications.MobileNetV2(input_shape=input_shape,
                                                   include_top=False,
                                                   weights='imagenet')
    # Freeze the pretrained backbone so only the new head is trained.
    base_model.trainable = False
    inputs = tf.keras.Input(shape=input_shape)
    x = data_augmentation(inputs)
    x = preprocess_input(x)
    x = base_model(x, training=False)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dropout(0.2)(x)
    # One raw logit per class (no softmax), hence from_logits=True in the loss.
    prediction_layer = tf.keras.layers.Dense(101)
    outputs = prediction_layer(x)
    model = tf.keras.Model(inputs, outputs)
    return model
model2 = FC_model(IMG_SIZE, data_augmentation)
model2.summary()
base_learning_rate = 0.001
model2.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
               loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
               metrics=['accuracy'])
initial_epochs = 10
history = model2.fit(train_dataset, validation_data=validation_dataset, epochs=initial_epochs)
The problem is in how you are loading the data:
train_dataset = image_dataset_from_directory(directory,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             label_mode="int",
                                             image_size=IMG_SIZE,
                                             validation_split=0.2,
                                             subset='training',
                                             seed=42)
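In particular, label_mode="int" means that your target variable is encoded as an integer (e.g. 0 for cat, 1 for dog, 2 for tree), so each batch of labels is treated as shape (None, 1) rather than (None, 101). You want to change it to label_mode="categorical", which yields one-hot vectors of length 101. Keep the loss consistent with that choice: one-hot labels go with CategoricalCrossentropy, whereas SparseCategoricalCrossentropy expects the integer labels.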
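To make the two label encodings concrete, here is a small framework-free sketch. It is not Keras's implementation; the helper names (one_hot, sparse_cross_entropy, categorical_cross_entropy) and the logits are made up for illustration. It shows that an integer label and its one-hot vector give the same loss, provided each is fed to the matching loss function.

```python
import math

NUM_CLASSES = 101  # matches the Dense(101) output layer in the question


def one_hot(label, num_classes=NUM_CLASSES):
    """Turn an integer class index into a one-hot vector of length num_classes."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def log_softmax(logits):
    """Numerically stable log-softmax over a list of raw scores."""
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return [z - log_total for z in logits]


def sparse_cross_entropy(label, logits):
    """Loss for an integer label: pick out the log-probability of the true class."""
    return -log_softmax(logits)[label]


def categorical_cross_entropy(target, logits):
    """Loss for a one-hot target: target and logits must have the same length."""
    if len(target) != len(logits):
        raise ValueError(
            f"Shapes ({len(target)},) and ({len(logits)},) are incompatible"
        )
    return -sum(t * lp for t, lp in zip(target, log_softmax(logits)))


# Made-up logits for one image, standing in for the model's 101 outputs.
logits = [0.01 * i for i in range(NUM_CLASSES)]
label = 7  # integer encoding, in the style of label_mode="int"

loss_from_int = sparse_cross_entropy(label, logits)
loss_from_one_hot = categorical_cross_entropy(one_hot(label), logits)
```

Passing the bare integer label to categorical_cross_entropy in this sketch raises a ValueError, which is the same kind of shape mismatch as the one in the title.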
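On the second part of the question (sampling the dataset and assigning the sample to training and testing), one possible approach is to sample a list of file paths first and split that list. The sketch below uses only the Python standard library; sample_and_split and the file names are hypothetical, and the 80/20 ratio is just an assumption that mirrors validation_split=0.2.

```python
import random


def sample_and_split(items, sample_size, train_fraction=0.8, seed=42):
    """Draw a random sample from items, then split it into train and test lists."""
    if not 0 < sample_size <= len(items):
        raise ValueError("sample_size must be between 1 and len(items)")
    rng = random.Random(seed)  # local generator, so the global random state is untouched
    sample = rng.sample(list(items), sample_size)  # sampling without replacement
    n_train = int(sample_size * train_fraction)
    return sample[:n_train], sample[n_train:]


# Hypothetical file names standing in for the images of one class folder.
all_images = [f"img_{i:04d}.jpg" for i in range(1000)]
train_files, test_files = sample_and_split(all_images, sample_size=200)
```

If the data is already a tf.data.Dataset, its take and skip methods can give a similar split at the batch level.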