[英]How to train image classification model using tf.data pipeline?
我正在训练大班图像分类模型。 由于数据集很大,它不适合 ram 内存,所以我使用tf.data
管道将其存储在缓存中并在训练进行时读取。 我构建的整体代码如下
file_location = r'C:\imageclassification\data'
data_dir = pathlib.Path(file_location)
class_labels_cnt = len(os.listdir(file_location))
print('total class :',class_labels_cnt)
image_count = len(list(data_dir.glob('*/*.png')))
batch_size = 64
img_height = 224
img_width = 224
epochs = 500
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)
class_names = np.array(sorted([item.name for item in data_dir.glob('*[!.csv]') if item.name != "LICENSE.txt"]))
print('class_names:',class_names)
val_size = int(image_count * 0.2)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)
print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())
def get_label(file_path):
# Convert the path to a list of path components
parts = tf.strings.split(file_path, os.path.sep)
# The second to last is the class-directory
one_hot = parts[-2] == class_names
# Integer encode the label
return tf.argmax(one_hot)
def decode_img(img):
# Convert the compressed string to a 3D uint8 tensor
img = tf.io.decode_jpeg(img, channels=3)
# Resize the image to the desired size
return tf.image.resize(img, [img_height, img_width])
return img
def process_path(file_path):
label = get_label(file_path)
# Load the raw data from the file as a string
img = tf.io.read_file(file_path)
img = decode_img(img)
return img, label
AUTOTUNE = tf.data.AUTOTUNE
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)
for image, label in train_ds.take(1):
print("Image shape: ", image.numpy().shape)
print("Label: ", label.numpy())
def configure_for_performance(ds):
ds = ds.cache()
ds = ds.shuffle(buffer_size=1000)
ds = ds.batch(batch_size)
ds = ds.prefetch(buffer_size=AUTOTUNE)
return ds
model = build_resnet_pretrained(input_shape=(224, 224, 3), no_classes=class_labels_cnt)
train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
model.compile( loss='categorical_crossentropy',
optimizer=Adam(amsgrad=True, decay=0.001 / epochs),
metrics=['accuracy']
)
history = model.fit(
x=train_ds,
epochs=epochs,
steps_per_epoch=int(tf.data.experimental.cardinality(train_ds).numpy()/ batch_size),
callbacks=callbacks,
validation_data=val_ds,
validation_steps=int(tf.data.experimental.cardinality(val_ds).numpy() / batch_size)
)
当我尝试运行它时,我最终收到如下错误
> WARNING:tensorflow:`period` argument is deprecated. Please use
> `save_freq` to specify the frequency in number of batches seen. Epoch
> 1/500 Traceback (most recent call last): File
> "C:\Users\john\Desktop\John_data\Project\All_Scripts\Scripts\Deep_Learning\ResNet\tfdata_clsfn.py",
> line 113, in <module>
> history = model.fit( File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py",
> line 1184, in fit
> tmp_logs = self.train_function(iterator) File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 885, in __call__
> result = self._call(*args, **kwds) File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 933, in _call
> self._initialize(args, kwds, add_initializers_to=initializers) File
> "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 759, in _initialize
> self._stateful_fn._get_concrete_function_internal_garbage_collected(
> # pylint: disable=protected-access File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\function.py",
> line 3066, in _get_concrete_function_internal_garbage_collected
> graph_function, _ = self._maybe_define_function(args, kwargs) File
> "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\function.py",
> line 3463, in _maybe_define_function
> graph_function = self._create_graph_function(args, kwargs) File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\function.py",
> line 3298, in _create_graph_function
> func_graph_module.func_graph_from_py_func( File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\framework\func_graph.py",
> line 1007, in func_graph_from_py_func
> func_outputs = python_func(*func_args, **func_kwargs) File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\eager\def_function.py",
> line 668, in wrapped_fn
> out = weak_wrapped_fn().__wrapped__(*args, **kwds) File "C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\framework\func_graph.py",
> line 994, in wrapper
> raise e.ag_error_metadata.to_exception(e) ValueError: in user code:
>
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:853
> train_function *
> return step_function(self, iterator)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:842
> step_function **
> outputs = model.distribute_strategy.run(run_step, args=(data,))
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1286
> run
> return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2849
> call_for_each_replica
> return self._call_for_each_replica(fn, args, kwargs)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3632
> _call_for_each_replica
> return fn(*args, **kwargs)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:835
> run_step **
> outputs = model.train_step(data)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\training.py:788
> train_step
> loss = self.compiled_loss(
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\engine\compile_utils.py:201
> __call__
> loss_value = loss_obj(y_t, y_p, sample_weight=sw)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\losses.py:141
> __call__
> losses = call_fn(y_true, y_pred)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\losses.py:245
> call **
> return ag_fn(y_true, y_pred, **self._fn_kwargs)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\util\dispatch.py:206
> wrapper
> return target(*args, **kwargs)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\losses.py:1665
> categorical_crossentropy
> return backend.categorical_crossentropy(
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\util\dispatch.py:206
> wrapper
> return target(*args, **kwargs)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\keras\backend.py:4839
> categorical_crossentropy
> target.shape.assert_is_compatible_with(output.shape)
> C:\Users\john\Anaconda3\envs\tf2_john\lib\site-packages\tensorflow\python\framework\tensor_shape.py:1161
> assert_is_compatible_with
> raise ValueError("Shapes %s and %s are incompatible" % (self, other))
>
> ValueError: Shapes (None, 1) and (None, 1922) are incompatible
>
>
> Process finished with exit code 1
我无法找出错误的原因,即使类的数量与输出 softmax 层数相同 任何纠正此错误的帮助或建议将不胜感激
我通过将损失函数更改为稀疏分类交叉熵来解决此问题
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.