
How can I preprocess a tf.data.Dataset using a provided preprocess_input function that expects a tf.Tensor?

I'm a bit stuck: I want to apply transfer learning to a problem using ResNet50 pre-trained on ImageNet.

I have the transfer learning process ready to go, but I need my dataset in the right form, which tf.keras.applications.resnet50.preprocess_input handily produces. However, that function works on a numpy.array or a tf.Tensor, and I'm loading the data with image_dataset_from_directory, which gives me a tf.data.Dataset.

Is there a simple way to use the provided preprocess_input function to preprocess my data in this form?

Alternatively, the function's documentation specifies:

The images are converted from RGB to BGR, then each color channel is zero-centered with respect to the ImageNet dataset, without scaling.

So any other way to achieve this, either in the data pipeline or as part of the model, would also be acceptable.
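To make the quoted description concrete, here is a minimal NumPy sketch of the transform it seems to describe. The helper name manual_resnet50_preprocess is hypothetical, and the per-channel mean values are an assumption (they are the ones commonly used for ImageNet in BGR order); check them against your Keras version before relying on them.

```python
import numpy as np

# Per-channel ImageNet means in BGR order.
# Assumed values: verify against the Keras version you are using.
IMAGENET_BGR_MEAN = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def manual_resnet50_preprocess(images_rgb):
    """Hypothetical helper: RGB -> BGR, then subtract the channel means (no scaling)."""
    images = np.asarray(images_rgb, dtype=np.float32)
    images_bgr = images[..., ::-1]          # reverse the last axis: RGB -> BGR
    return images_bgr - IMAGENET_BGR_MEAN   # broadcasts over the channel axis; returns a new array

# One pure-red RGB pixel, shape (1, 1, 3), in the 0-255 range.
red_pixel = np.array([[[255.0, 0.0, 0.0]]])
print(manual_resnet50_preprocess(red_pixel))
```

For the red pixel, the blue and green channels come out negative and the red channel ends up at about 131.3, which shows that the result is zero-centered rather than scaled to [0, 1].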

You could use the dataset's map method to apply the preprocess_input function to your images:

import tensorflow as tf
import pathlib
import matplotlib.pyplot as plt

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)

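def display(ds):
  # Take the first batch and show its first image.
  images, _ = next(iter(ds.take(1)))
  # Divide by 255 for plotting. After preprocess_input the values are
  # zero-centered BGR, so imshow clips anything outside [0, 1].
  image = images[0].numpy() / 255.0
  plt.imshow(image)
  plt.show()

def preprocess(images, labels):
  # map passes each (images, labels) batch as tensors, which is the
  # input type preprocess_input expects. Labels pass through unchanged.
  return tf.keras.applications.resnet50.preprocess_input(images), labels

train_ds = train_ds.map(preprocess)

display(train_ds)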
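Since the question also allows doing this as part of the model, one possible alternative is to wrap preprocess_input in a Lambda layer at the front of the model, so that raw 0-255 RGB batches can be fed in directly. This is only a sketch: the layer name resnet50_preprocess and the random stand-in batch are made up for illustration, and the ResNet50 base that would normally follow is left out to keep the example small.

```python
import numpy as np
import tensorflow as tf

preprocess_input = tf.keras.applications.resnet50.preprocess_input

# Preprocessing as the first layer of the model (hypothetical layer name).
inputs = tf.keras.Input(shape=(180, 180, 3))
outputs = tf.keras.layers.Lambda(preprocess_input, name="resnet50_preprocess")(inputs)
# In a real transfer-learning model, `outputs` would feed the ResNet50 base here.
preprocessing_model = tf.keras.Model(inputs, outputs)

# Stand-in for one raw RGB batch from image_dataset_from_directory (float32, 0-255).
raw_batch = np.random.uniform(0.0, 255.0, size=(2, 180, 180, 3)).astype("float32")

# The model output is the preprocessed (BGR, zero-centered) batch.
processed = np.asarray(preprocessing_model(raw_batch))
print(processed.shape)
```

With this variant the dataset is left untouched, so the same preprocessing is applied automatically at training and inference time. Don't combine it with the map approach above, or the images would be preprocessed twice.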

