繁体   English   中英

如何使用提供的需要 tf.Tensor 的 preprocess_input function 预处理 tf.data.Dataset?

[英]How can I preprocess a tf.data.Dataset using a provided preprocess_input function that expects a tf.Tensor?

有点不知所措,我希望使用在ImageNet上预训练的 ResNet50 将迁移学习应用于问题。

我已经准备好 go 的迁移学习过程,但需要我的数据集以正确的形式,而tf.keras.applications.resnet50.preprocess_input轻松完成。 除了它适用于numpy.arraytf.Tensor ,我正在使用image_dataset_from_directory加载数据,这给了我一个tf.data.Dataset

有没有一种简单的方法可以使用提供的preprocess_input function 以这种形式预处理我的数据?

或者,function 指定:

图像从 RGB 转换为 BGR,然后每个颜色通道相对于 ImageNet 数据集以零为中心,无需缩放。

因此,在数据管道中或作为 model 的一部分实现此目的的任何其他方式也是可以接受的。

您可以使用map function 将preprocess_input function 应用于您的图像:

import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)

def display(ds):
  images, _ = next(iter(ds.take(1)))
  image = images[0].numpy()
  image /= 255.0
  plt.imshow(image)

def preprocess(images, labels):
  return tf.keras.applications.resnet50.preprocess_input(images), labels

train_ds = train_ds.map(preprocess)

display(train_ds)

在此处输入图像描述

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM