[英]Data Augmentation on tf.dataset.Dataset
为了使用 Google Colabs TPU,我需要一个tf.dataset.Dataset
。 那么如何在这样的数据集上使用数据增强?
更具体地说,到目前为止我的代码是:
def get_dataset(batch_size=200):
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
try_gcs=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255.0
label = tf.one_hot(label,10)
return image, label
train_dataset = mnist_train.map(scale).shuffle(10000).batch(batch_size)
test_dataset = mnist_test.map(scale).batch(batch_size)
return train_dataset, test_dataset
这被送入:
# TPU Strategy ...
with strategy.scope():
model = create_model()
model.compile(loss="categorical_crossentropy",
optimizer="adam",
metrics=["acc"])
train_dataset, test_dataset = get_dataset()
model.fit(train_dataset,
epochs=20,
verbose=1,
validation_data=test_dataset)
那么,我如何在这里使用数据增强呢? 据我所知,我不能使用 tf.keras ImageDataGenerator,对吧?
我尝试了以下方法,但没有奏效。
data_generator = ...
model.fit_generator(data_generator.flow(train_dataset, batch_size=32),
steps_per_epoch=len(train_dataset) / 32, epochs=20)
这并不奇怪,因为通常 train_x 和 train_y 作为两个 arguments 馈送到流 function 中,而不是“打包”到一个tf.dataset.Dataset
中。
您可以使用tf.image函数。 tf.image
模块包含用于图像处理的各种功能。
例如:
您可以在 function def get_dataset
中添加以下功能。
0-1
范围内的tf.float64
。cache()
结果,因为这些结果可以在每次repeat
后重复使用random_flip_left_right
随机翻转 left_to_right 每个图像。random_contrast
随机改变图像的对比度。repeat
重复所有步骤,图像数量增加了两倍。代码 -
mnist_train = mnist_train.map(
lambda image, label: (tf.image.convert_image_dtype(image, tf.float32), label)
).cache(
).map(
lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
1000
).
batch(
batch_size
).repeat(2)
同样,您可以使用其他功能,如random_flip_up_down
、 random_crop
函数来随机垂直翻转图像(上下颠倒)并分别将张量随机裁剪为给定大小。
您的get_dataset
function 将如下所示 -
def get_dataset(batch_size=200):
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True,
try_gcs=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
train_dataset = mnist_train.map(
lambda image, label: (tf.image.convert_image_dtype(image, tf.float32),label)
).cache(
).map(
lambda image, label: (tf.image.random_flip_left_right(image), label)
).map(
lambda image, label: (tf.image.random_contrast(image, lower=0.0, upper=1.0), label)
).shuffle(
1000
).batch(
batch_size
).repeat(2)
test_dataset = mnist_test.map(scale).batch(batch_size)
return train_dataset, test_dataset
添加@Andrew H 建议的链接,该链接提供了同样使用mnist
数据集的数据增强的端到端示例。
希望这能回答你的问题。 快乐学习。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.