简体   繁体   English

如何使用 tf.data 管道训练图像分类模型?

[英]How to train image classification model using tf.data pipeline?

I'm training large classes image classification model.我正在训练大班图像分类模型。 Since dataset is huge it wont fit inside ram memory so I'm using tf.data pipeline to store it in cache and read as training goes on.由于数据集很大,它不适合 ram 内存,所以我使用tf.data管道将其存储在缓存中并在训练进行时读取。 Overall code which I've built is as below我构建的整体代码如下

file_location = r'C:\imageclassification\data'
data_dir = pathlib.Path(file_location)
class_labels_cnt = len(os.listdir(file_location))
print('total class :',class_labels_cnt)
image_count = len(list(data_dir.glob('*/*.png')))

batch_size = 64
img_height = 224
img_width = 224
epochs = 500
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)


class_names = np.array(sorted([item.name for item in data_dir.glob('*[!.csv]') if item.name != "LICENSE.txt"]))
print('class_names:',class_names)
val_size = int(image_count * 0.2)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)


print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())

def get_label(file_path):
  # Convert the path to a list of path components
  parts = tf.strings.split(file_path, os.path.sep)
  # The second to last is the class-directory
  one_hot = parts[-2] == class_names
  # Integer encode the label
  return tf.argmax(one_hot)


def decode_img(img):
  # Convert the compressed string to a 3D uint8 tensor
  img = tf.io.decode_jpeg(img, channels=3)
  # Resize the image to the desired size
  return tf.image.resize(img, [img_height, img_width])
  return img
def process_path(file_path):
  label = get_label(file_path)
  # Load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label
AUTOTUNE = tf.data.AUTOTUNE
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)

for image, label in train_ds.take(1):
  print("Image shape: ", image.numpy().shape)
  print("Label: ", label.numpy())

def configure_for_performance(ds):
  ds = ds.cache()
  ds = ds.shuffle(buffer_size=1000)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(buffer_size=AUTOTUNE)
  return ds
model = build_resnet_pretrained(input_shape=(224, 224, 3), no_classes=class_labels_cnt)

train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)



model.compile( loss='categorical_crossentropy',
               optimizer=Adam(amsgrad=True, decay=0.001 / epochs),
               metrics=['accuracy']
             )

history = model.fit(
    x=train_ds,
    epochs=epochs,
    steps_per_epoch=int(tf.data.experimental.cardinality(train_ds).numpy()/ batch_size),
    callbacks=callbacks,
    validation_data=val_ds,
    validation_steps=int(tf.data.experimental.cardinality(val_ds).numpy() / batch_size)
)

when I try to run this I endup getting error as below当我尝试运行它时,我最终收到如下错误

> WARNING:tensorflow:`period` argument is deprecated. Please use
> `save_freq` to specify the frequency in number of batches seen. Epoch
> 1/500 Traceback (most recent call last):   File
> "C:\Users\john\Desktop\John_data\Project\All_Scripts\Scripts\Deep_Learning\ResNet\tfdata_clsfn.py",
> line 113, in <module>
>     history = model.fit(   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py",
> line 1184, in fit
>     tmp_logs = self.train_function(iterator)   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 885, in __call__
>     result = self._call(*args, **kwds)   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 933, in _call
>     self._initialize(args, kwds, add_initializers_to=initializers)   File
> "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 759, in _initialize
>     self._stateful_fn._get_concrete_function_internal_garbage_collected( 
> # pylint: disable=protected-access   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\function.py",
> line 3066, in _get_concrete_function_internal_garbage_collected
>     graph_function, _ = self._maybe_define_function(args, kwargs)   File
> "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\function.py",
> line 3463, in _maybe_define_function
>     graph_function = self._create_graph_function(args, kwargs)   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\function.py",
> line 3298, in _create_graph_function
>     func_graph_module.func_graph_from_py_func(   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\framework\func_graph.py",
> line 1007, in func_graph_from_py_func
>     func_outputs = python_func(*func_args, **func_kwargs)   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 668, in wrapped_fn
>     out = weak_wrapped_fn().__wrapped__(*args, **kwds)   File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\framework\func_graph.py",
> line 994, in wrapper
>     raise e.ag_error_metadata.to_exception(e) ValueError: in user code:
> 
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:853
> train_function  *
>         return step_function(self, iterator)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:842
> step_function  **
>         outputs = model.distribute_strategy.run(run_step, args=(data,))
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1286
> run
>         return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2849
> call_for_each_replica
>         return self._call_for_each_replica(fn, args, kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3632
> _call_for_each_replica
>         return fn(*args, **kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:835
> run_step  **
>         outputs = model.train_step(data)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:788
> train_step
>         loss = self.compiled_loss(
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\compile_utils.py:201
> __call__
>         loss_value = loss_obj(y_t, y_p, sample_weight=sw)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\losses.py:141
> __call__
>         losses = call_fn(y_true, y_pred)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\losses.py:245
> call  **
>         return ag_fn(y_true, y_pred, **self._fn_kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\util\dispatch.py:206
> wrapper
>         return target(*args, **kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\losses.py:1665
> categorical_crossentropy
>         return backend.categorical_crossentropy(
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\util\dispatch.py:206
> wrapper
>         return target(*args, **kwargs)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\backend.py:4839
> categorical_crossentropy
>         target.shape.assert_is_compatible_with(output.shape)
>     C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\framework\tensor_shape.py:1161
> assert_is_compatible_with
>         raise ValueError("Shapes %s and %s are incompatible" % (self, other))
> 
>     ValueError: Shapes (None, 1) and (None, 1922) are incompatible
> 
> 
> Process finished with exit code 1

I'm not able to figureout the cause of the error , even though number of classes are same as output softmax layers Any help or suggestion to rectify this error will be appreciated very much我无法找出错误的原因,即使类的数量与输出 softmax 层数相同 任何纠正此错误的帮助或建议将不胜感激

我通过将损失函数更改为稀疏分类交叉熵来解决此问题

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM