How can I preprocess a tf.data.Dataset using a provided preprocess_input function that expects a tf.Tensor?
Having a bit of a clueless moment. I'm looking to apply transfer learning to a problem using ResNet50 pre-trained on ImageNet.
I've got the transfer learning process all ready to go, but I need my dataset in the right form, which tf.keras.applications.resnet50.preprocess_input handily takes care of. The catch is that it works on a numpy.array or tf.Tensor, and I'm using image_dataset_from_directory to load the data, which gives me a tf.data.Dataset.
Is there a simple way to use the provided preprocess_input function to preprocess my data in this form?
Alternatively, the function's documentation specifies:
The images are converted from RGB to BGR, then each color channel is zero-centered with respect to the ImageNet dataset, without scaling.
So any other way to achieve this, either in the data pipeline or as part of the model, would also be acceptable.
You could use the dataset's map method to apply the preprocess_input function to your images. Inside map, each batch arrives as a tf.Tensor, which is exactly what preprocess_input accepts:
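import pathlib

import matplotlib.pyplot as plt
import tensorflow as tf

# Download and unpack the example flower photos.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

# Yields batches of (images, labels); images are float32 RGB in [0, 255].
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(180, 180),
    batch_size=batch_size)

def display(ds):
    # Show the first image of the first batch.
    images, _ = next(iter(ds.take(1)))
    # Divide out of place so the array returned by .numpy() is not mutated.
    image = images[0].numpy() / 255.0
    plt.imshow(image)
    plt.show()

def preprocess(images, labels):
    # images is a tf.Tensor here, so preprocess_input can be applied directly.
    return tf.keras.applications.resnet50.preprocess_input(images), labels

train_ds = train_ds.map(preprocess)

# After preprocessing the pixels are zero-centered BGR, so imshow clips them
# and the colors look off; that is expected.
display(train_ds)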
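To check that the mapped dataset really matches the description quoted in the question (RGB to BGR, per-channel zero-centering, no scaling), here is a minimal sketch that runs the same map call on a small synthetic dataset and compares the result with a manual NumPy computation. The random images stand in for what image_dataset_from_directory returns. The per-channel means in IMAGENET_BGR_MEAN are an assumption: they are the values I understand Keras to use for this preprocessing mode, so verify them against your installed version. The use of num_parallel_calls is an optional addition, not part of the code above.

```python
import numpy as np
import tensorflow as tf

# Synthetic stand-in for image_dataset_from_directory output:
# float32 RGB images in [0, 255] with integer labels, batched.
rng = np.random.default_rng(0)
images = rng.uniform(0.0, 255.0, size=(8, 32, 32, 3)).astype("float32")
labels = (np.arange(8) % 2).astype("int32")
ds = tf.data.Dataset.from_tensor_slices((images, labels)).batch(4)

def preprocess(batch_images, batch_labels):
    # Inside map, batch_images is a tf.Tensor, which preprocess_input accepts.
    return tf.keras.applications.resnet50.preprocess_input(batch_images), batch_labels

# AUTOTUNE lets tf.data pick the parallelism; element order is still preserved.
processed_ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

# Manual version of the documented behavior (assumed mean values, in BGR order):
# reverse the channel axis, then subtract the per-channel ImageNet means.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype="float32")
expected = images[..., ::-1] - IMAGENET_BGR_MEAN

# Collect the mapped batches back into one array and compare.
processed = np.concatenate([x.numpy() for x, _ in processed_ds], axis=0)
max_abs_diff = float(np.abs(processed - expected).max())
print("max abs difference:", max_abs_diff)
```

If the difference is close to zero, the same manual transformation could also be written directly in the map function or inside the model, which covers the alternative asked about in the question.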