繁体   English   中英

InvalidArgumentError with model.fit in Tensorflow

[英]InvalidArgumentError with model.fit in Tensorflow

使用 CNN 进行图像分类。 当调用model.fit()时,它开始训练 model 一段时间,并在执行过程中中断并返回错误消息。

错误信息如下

InvalidArgumentError: 2 root error(s) found.
  (0) Invalid argument:  Input size should match (header_size + row_size * abs_height) but they differ by 2
     [[{{node decode_image/DecodeImage}}]]
     [[IteratorGetNext]]
     [[IteratorGetNext/_4]]
  (1) Invalid argument:  Input size should match (header_size + row_size * abs_height) but they differ by 2
     [[{{node decode_image/DecodeImage}}]]
     [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_train_function_8873]

Function call stack:
train_function -> train_function

更新:我的建议是检查数据集的元数据。 它帮助解决了我的问题。

您不必指定参数label_mode 为了使用SparseCategoricalCrossentropy作为损失 function 您需要将其设置为int 如果您未指定它,则根据文档将其设置为None

您还需要根据从中读取图像的目录结构指定要inferred的参数labels

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  labels="inferred",
  label_mode="int",
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)
  
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  labels="inferred",
  label_mode="int",
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size)

我刚刚在另一个帖子中回答了一个非常相似的问题。 事实上,根本问题可能完全相同。 它包含对至少在我的情况下发生的事情的详细解释。 长话短说,我证明是正确的一个可能原因是JPEG 文件损坏

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM