繁体   English   中英

在 Tensorflow 2.3 中规范化 BatchDataset

[英]Normalizing BatchDataset in Tensorflow 2.3

我正在使用 TF 2.3 中的tf.keras.preprocessing.image_dataset_from_directory从目录(训练/测试拆分)加载图像。 我得到的是一个tf.data.Dataset( tensorflow.python.data.ops.dataset_ops.BatchDataset与形状实际上)对象:

train_ds.take(1)
# <TakeDataset shapes: ((None, 256, 256, 3), (None, 6)), types: (tf.float32, tf.float32)>
for images, labels in train_ds.take(1):
    print(images.shape)
    print(images[0])
# (32, 256, 256, 3)
# tf.Tensor(
# [[[225.75  225.75  225.75 ]
#   [225.75  225.75  225.75 ]
#   [225.75  225.75  225.75 ]
#   ...
#   [215.    214.    209.   ]
#   [215.    214.    209.   ]
#   [215.    214.    209.   ]]
#
#  ...], shape=(256, 256, 3), dtype=float32)

我无法弄清楚如何使用该 Dataset 对象标准化图像( /= 255 )。 我试图打/=运营商本身, mapapply方法,甚至流延对象名单提到这里 似乎没有任何效果,我真的很想在数据集级别解决这个问题,而不是向我的网络添加规范化层。

有任何想法吗?

试试这个方法:

def process(image,label):
    image = tf.cast(image/255. ,tf.float32)
    return image,label

ds = tf.keras.preprocessing.image_dataset_from_directory(IMAGE_DIR)
ds = ds.map(process)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM