简体   繁体   English

如何在 tensorflow 数据集上使用 mobilenet_v2.preprocess_input

[英]How to use mobilenet_v2.preprocess_input on tensorflow dataset

I'm again struggling with the usage of tensorflow datasets.我再次为 tensorflow 数据集的使用而苦苦挣扎。 I'm again loading my images via我再次通过

data = keras.preprocessing.image_dataset_from_directory(
  './data', 
  labels='inferred', 
  label_mode='binary', 
  validation_split=0.2, 
  subset="training", 
  image_size=(img_height, img_width), 
  batch_size=sz_batch, 
  crop_to_aspect_ratio=True
)

I want to use this dataset in the pre-trained MobileNetV2我想在预训练的 MobileNetV2 中使用这个数据集

model = keras.applications.mobilenet_v2.MobileNetV2(input_shape=(img_height, img_width, 3), weights='imagenet')

The documentation says, that the input data must be scaled to be between -1 and 1 ( https://www.tensorflow.org/api_docs/python/tf/keras/applications/mobilenet_v2/MobileNetV2 ).文档说,输入数据必须缩放到 -1 和 1 之间( https://www.tensorflow.org/api_docs/python/tf/keras/applications/mobilenet_v2/MobileNetV2 )。 To do so, the preprocess_input function is provided.为此,提供了preprocess_input function。 When I use this function on my dataset当我在我的数据集上使用这个 function

scaled_data = tf.keras.applications.mobilenet_v2.preprocess_input(data)

I get the error: TypeError: unsupported operand type(s) for /=: 'BatchDataset' and 'float'我收到错误: TypeError: unsupported operand type(s) for /=: 'BatchDataset' and 'float'

So how can I use this function properly with the tensorflow dataset?那么如何将这个 function 与 tensorflow 数据集正确使用?

Maybe try using tf.data.Dataset.map :也许尝试使用tf.data.Dataset.map

import tensorflow as tf
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)

def preprocess(images, labels):
  return tf.keras.applications.mobilenet_v2.preprocess_input(images), labels

train_ds = train_ds.map(preprocess)

images, _ = next(iter(train_ds.take(1)))
image = images[0]
plt.imshow(image.numpy())

Before preprocessing the images:在预处理图像之前:

在此处输入图像描述

After preprocessing the images with tf.keras.applications.mobilenet_v2.preprocess_input only:仅使用tf.keras.applications.mobilenet_v2.preprocess_input预处理图像后:

在此处输入图像描述

After preprocessing the images with tf.keras.layers.Rescaling(1./255) only:仅使用tf.keras.layers.Rescaling(1./255)预处理图像后: 在此处输入图像描述

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 tensorflow中的MobileNet预处理输入如何 - How is the MobileNet preprocess input in tensorflow 如何预处理 tensorflow imdb_review 数据集 - How to preprocess tensorflow imdb_review dataset 如何预处理 Tensorflow 2.x 中实现的 BERT model 的数据集? - How to preprocess a dataset for BERT model implemented in Tensorflow 2.x? 使用 Tensorflow 构建 RNN。 如何正确预处理我的数据集以匹配 RNN 的输入和 output 形状? - Building RNN with Tensorflow. How do I preprocess my dataset correctly to match the RNN's input and output shape? 如何在Tensorflow V2中将tf.data.Dataset与图表一起使用? - How to use `tf.data.Dataset` in Tensorflow V2 with graphs? 使用Tensorflow Dataset API读取TFRecords文件时,预处理输入数据会减慢输入管道的速度 - Preprocess the input data slow down the input pipeline when using Tensorflow Dataset API to read TFRecords file TensorFlow 如何计算 vgg19.preprocess_input 的梯度? - How does TensorFlow compute the gradient of vgg19.preprocess_input? 如何使用单个数据集在 tensorflow keras 中训练多个输入 model - how to use a single dataset to train multiple input model in tensorflow keras 如何正确使用 tensorflow 数据集用于具有 keras 的多个输入层 - How to use tensorflow dataset correctly for multiple input layers with keras 如何使用提供的需要 tf.Tensor 的 preprocess_input function 预处理 tf.data.Dataset? - How can I preprocess a tf.data.Dataset using a provided preprocess_input function that expects a tf.Tensor?
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM