繁体   English   中英

TensorFlow2:ResNet50 - ValueError

[英]TensorFlow2: ResNet50 - ValueError

我正在尝试在 TensorFlow2 中使用 ResNet-50 和 Keras 在具有(32、32、3)图像的 CIFAR-10 数据集上使用迁移学习。

默认 ResNet-50 的第一个 conv 层使用 (7, 7) 的过滤器大小,步幅 = 2,由此产生的 CIFAR-10 在空间上减少了太多,这是应该避免的。 作为“黑客”,尝试将图像从 (32, 32) 放大到 (224, 224)。 代码是:

import tensorflow.keras as K

# Define KerasTensor as input-
input_t = K.Input(shape = (32, 32, 3))

res_model = K.applications.ResNet50(
    include_top = False,
    weights = "imagenet",
    input_tensor = input_t
)

# Since CIFAR-10 dataset is small as compared to ImageNet, the images are upscaled to (224, 224)-
to_res = (224, 224)

model = K.models.Sequential()
model.add(K.layers.Lambda(lambda image: tf.image.resize(image, to_res))) 
model.add(res_model)
model.add(K.layers.Flatten())
model.add(K.layers.BatchNormalization())
model.add(K.layers.Dense(units = 10, activation = 'softmax'))

# Choose an optimizer and loss function for training-
loss_fn = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD(learning_rate = 0.1, momentum = 0.9)

model.compile(
    # loss = 'categorical_crossentropy',
    loss = loss_fn,
    # optimizer = K.optimizers.RMSprop(lr=2e-5),
    optimizer = optimizer,
    metrics=['accuracy']
)

history = model.fit(
    x = X_train, y = y_train,
    batch_size = batch_size, epochs = 10,
    validation_data = (X_test, y_test),
    # callbacks=[check_point]
    )

我得到错误:

纪元 1/10 警告:tensorflow:Model 是用形状 (None, 32, 32, 3) 构造的,用于输入 KerasTensor(type_spec=TensorSpec(shape=(None, 32, 32, 32), name=32, 'input_1'), name='input_1', description="created by layer 'input_1'"),但它是在形状不兼容的输入上调用的 (None, 224, 224, 3)。


ValueError Traceback(最近一次调用最后一次)

in () 2 x = X_train, y = y_train, 3 batch_size = batch_size, epochs = 10, ----> 4 validation_data = (X_test, y_test), 5 # callbacks=[check_point] 6)

9帧

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs) 975 except Exception as e: # pylint:disable=broad-except 976 if hasattr(e, "ag_error_metadata"): --> 977 raise e.ag_error_metadata.to_exception(e) 978 else: 979 raise

ValueError:在用户代码中:

ValueError: Input 0 is in compatible with layer resnet50: expected shape=(None, 32, 32, 3), found shape=(None, 224, 224, 3)

model的输入还是(32,32,3)

input_t = K.Input(shape = (32, 32, 3))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM