![](/img/trans.png)
[英]M1 Max Tensorflow model.fit() InvalidArgumentError
[英]InvalidArgumentError with model.fit in Tensorflow
使用 CNN 进行图像分类。 当调用model.fit()
时,它开始训练 model 一段时间,并在执行过程中中断并返回错误消息。
错误信息如下
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: Input size should match (header_size + row_size * abs_height) but they differ by 2
[[{{node decode_image/DecodeImage}}]]
[[IteratorGetNext]]
[[IteratorGetNext/_4]]
(1) Invalid argument: Input size should match (header_size + row_size * abs_height) but they differ by 2
[[{{node decode_image/DecodeImage}}]]
[[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_8873]
Function call stack:
train_function -> train_function
更新:我的建议是检查数据集的元数据。 它帮助解决了我的问题。
您不必指定参数label_mode
。 为了使用SparseCategoricalCrossentropy
作为损失 function 您需要将其设置为int
。 如果您未指定它,则根据文档将其设置为None
。
您还需要根据从中读取图像的目录结构指定要inferred
的参数labels
。
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
labels="inferred",
label_mode="int",
validation_split=0.2,
subset="training",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
data_dir,
labels="inferred",
label_mode="int",
validation_split=0.2,
subset="validation",
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)
我刚刚在另一个帖子中回答了一个非常相似的问题。 事实上,根本问题可能完全相同。 它包含对至少在我的情况下发生的事情的详细解释。 长话短说,我证明是正确的一个可能原因是JPEG 文件损坏。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.