I'm having a bit of a clueless moment. I want to apply transfer learning to a problem using ResNet50 pre-trained on ImageNet.
I have the transfer learning process ready to go, but I need my dataset in the right form, which tf.keras.applications.resnet50.preprocess_input handily provides. However, that function works on a numpy.array or tf.Tensor, and I'm loading the data with image_dataset_from_directory, which gives me a tf.data.Dataset.
Is there a simple way to use the provided preprocess_input function to preprocess my data in this form?
Alternatively, the function's documentation specifies:
The images are converted from RGB to BGR, then each color channel is zero-centered with respect to the ImageNet dataset, without scaling.
So any other way to achieve this in the data pipeline or as part of the model would also be acceptable.
You could use the dataset's map method to apply the preprocess_input function to your images:
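import pathlib

import matplotlib.pyplot as plt
import tensorflow as tf

# Download and unpack the example flower dataset.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

# Yields batches of (images, labels); images are float32 RGB in [0, 255].
train_ds = tf.keras.utils.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(180, 180),
    batch_size=batch_size)

def display(ds):
    # Show the first image of the first batch.
    images, _ = next(iter(ds.take(1)))
    image = images[0].numpy()
    image /= 255.0
    # After preprocessing the values are zero-centered BGR, so imshow
    # clips anything outside [0, 1] and the colors look distorted.
    plt.imshow(image)

def preprocess(images, labels):
    # map passes each (images, labels) batch; only the images are transformed.
    return tf.keras.applications.resnet50.preprocess_input(images), labels

train_ds = train_ds.map(preprocess)
display(train_ds)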
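As a cross-check on what the mapped function does, the transformation quoted in the question (RGB to BGR, then zero-centering each channel, without scaling) can be written out by hand. This is only a minimal NumPy sketch, not the library's implementation: manual_preprocess is a hypothetical helper name, and the mean values are the commonly cited ImageNet BGR channel means, stated here as an assumption.

```python
import numpy as np

# Assumed ImageNet per-channel means, in BGR order (not taken from the post).
IMAGENET_MEAN_BGR = np.array([103.939, 116.779, 123.68], dtype=np.float32)

def manual_preprocess(images):
    """Hypothetical helper: RGB -> BGR, then zero-center each channel (no scaling)."""
    images = np.asarray(images, dtype=np.float32)
    bgr = images[..., ::-1]          # reverse the last (channel) axis
    return bgr - IMAGENET_MEAN_BGR   # broadcasts over batch, height, width

# A tiny fake batch: red = 255, green = 128, blue = 0 everywhere.
batch = np.zeros((2, 4, 4, 3), dtype=np.float32)
batch[..., 0] = 255.0
batch[..., 1] = 128.0
out = manual_preprocess(batch)
```

Comparing the output of such a helper with preprocess_input on the same batch is a quick way to confirm that the images coming out of the dataset are in the form ResNet50 expects.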