繁体   English   中英

如何使用 tf.data 管道训练图像分类模型?

[英]How to train image classification model using tf.data pipeline?

我正在训练大班图像分类模型。 由于数据集很大,它不适合 ram 内存,所以我使用tf.data管道将其存储在缓存中并在训练进行时读取。 我构建的整体代码如下

file_location = r'C:\imageclassification\data'
data_dir = pathlib.Path(file_location)
class_labels_cnt = len(os.listdir(file_location))
print('total class :',class_labels_cnt)
image_count = len(list(data_dir.glob('*/*.png')))

batch_size = 64
img_height = 224
img_width = 224
epochs = 500
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)


class_names = np.array(sorted([item.name for item in data_dir.glob('*[!.csv]') if item.name != "LICENSE.txt"]))
print('class_names:',class_names)
val_size = int(image_count * 0.2)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)


print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())

def get_label(file_path):
  # Convert the path to a list of path components
  parts = tf.strings.split(file_path, os.path.sep)
  # The second to last is the class-directory
  one_hot = parts[-2] == class_names
  # Integer encode the label
  return tf.argmax(one_hot)


def decode_img(img):
  # Convert the compressed string to a 3D uint8 tensor
  img = tf.io.decode_jpeg(img, channels=3)
  # Resize the image to the desired size
  return tf.image.resize(img, [img_height, img_width])
  return img
def process_path(file_path):
  label = get_label(file_path)
  # Load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label
AUTOTUNE = tf.data.AUTOTUNE
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)

for image, label in train_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())

def configure_for_performance(ds):
  ds = ds.cache()
  ds = ds.shuffle(buffer_size=1000)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(buffer_size=AUTOTUNE)
  return ds
model = build_resnet_pretrained(input_shape=(224, 224, 3), no_classes=class_labels_cnt)

train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)



model.compile( loss='categorical_crossentropy',
               optimizer=Adam(amsgrad=True, decay=0.001 / epochs),
               metrics=['accuracy']
             )

history = model.fit(
    x=train_ds,
    epochs=epochs,
    steps_per_epoch=int(tf.data.experimental.cardinality(train_ds).numpy()/ batch_size),
    callbacks=callbacks,
    validation_data=val_ds,
    validation_steps=int(tf.data.experimental.cardinality(val_ds).numpy() / batch_size)
)

当我尝试运行它时,我最终收到如下错误

> WARNING:tensorflow:`period` argument is deprecated. Please use
> `save_freq` to specify the frequency in number of batches seen. Epoch
> 1/500 Traceback (most recent call last):   File
> "C:\Users\john\Desktop\John_data\Project\All_Scripts\Scripts\Deep_Learning\ResNet\tfdata_clsfn.py",
> line 113, in <module>
>     history = model.fit(   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py",
> line 1184, in fit
>     tmp_logs = self.train_function(iterator)   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 885, in __call__
>     result = self._call(*args, **kwds)   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 933, in _call
>     self._initialize(args, kwds, add_initializers_to=initializers)   File
> "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 759, in _initialize
>     self._stateful_fn._get_concrete_function_internal_garbage_collected( 
> # pylint: disable=protected-access   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\function.py",
> line 3066, in _get_concrete_function_internal_garbage_collected
>     graph_function, _ = self._maybe_define_function(args, kwargs)   File
> "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\function.py",
> line 3463, in _maybe_define_function
>     graph_function = self._create_graph_function(args, kwargs)   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\function.py",
> line 3298, in _create_graph_function
>     func_graph_module.func_graph_from_py_func(   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\framework\func_graph.py",
> line 1007, in func_graph_from_py_func
>     func_outputs = python_func(*func_args, **func_kwargs)   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 668, in wrapped_fn
>     out = weak_wrapped_fn().__wrapped__(*args, **kwds)   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\framework\func_graph.py",
> line 994, in wrapper
>     raise e.ag_error_metadata.to_exception(e) ValueError: in user code:
> 
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:853
> train_function  *
>         return step_function(self, iterator)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:842
> step_function  **
>         outputs = model.distribute_strategy.run(run_step, args=(data,))
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1286
> run
>         return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2849
> call_for_each_replica
>         return self._call_for_each_replica(fn, args, kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3632
> _call_for_each_replica
>         return fn(*args, **kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:835
> run_step  **
>         outputs = model.train_step(data)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:788
> train_step
>         loss = self.compiled_loss(
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\compile_utils.py:201
> __call__
>         loss_value = loss_obj(y_t, y_p, sample_weight=sw)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\losses.py:141
> __call__
>         losses = call_fn(y_true, y_pred)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\losses.py:245
> call  **
>         return ag_fn(y_true, y_pred, **self._fn_kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\util\dispatch.py:206
> wrapper
>         return target(*args, **kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\losses.py:1665
> categorical_crossentropy
>         return backend.categorical_crossentropy(
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\util\dispatch.py:206
> wrapper
>         return target(*args, **kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\backend.py:4839
> categorical_crossentropy
>         target.shape.assert_is_compatible_with(output.shape)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\framework\tensor_shape.py:1161
> assert_is_compatible_with
>         raise ValueError("Shapes %s and %s are incompatible" % (self, other))
> 
>     ValueError: Shapes (None, 1) and (None, 1922) are incompatible
> 
> 
> Process finished with exit code 1

我无法找出错误的原因,即使类的数量与输出 softmax 层数相同 任何纠正此错误的帮助或建议将不胜感激

我通过将损失函数更改为稀疏分类交叉熵来解决此问题

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM