[英]How to random_crop an unlabeled tensorflow Dataset? ValueError: Dimensions must be equal, but are 4 and 3
我正在尝试在使用 tensorflow 数据集加载图像时增加(随机裁剪)图像。 当我在映射的 function 中调用方法tf.image.random_crop时出现此错误:
ValueError: Dimensions must be equal, but are 4 and 3 for '{{node random_crop/GreaterEqual}} = GreaterEqual[T=DT_INT32](random_crop/Shape, random_crop/size)' with input shapes: [4], [3].
为了重现错误,只需在目录中放置一些 png 图像:
./img/class0/
然后运行这段代码:
import os
import tensorflow as tf
train_set_raw = tf.keras.preprocessing.image_dataset_from_directory('./img',label_mode=None,validation_split=None,batch_size=32)
def augment(tensor):
tensor = tf.cast(x=tensor, dtype=tf.float32)
tensor = tf.divide(x=tensor, y=tf.constant(255.))
tensor = tf.image.random_crop(value=tensor, size=(256, 256, 3))
return tensor
train_set_raw = train_set_raw.map(augment).batch(32)
如果我明确指定批量大小,
tensor = tf.image.random_crop(value=tensor, size=(32,256, 256, 3))
可以对错误进行排序。 但是,如果您尝试将 model 与使用固定批量大小创建的数据集拟合,您将收到错误消息:
tensorflow.python.framework.errors_impl.InvalidArgumentError: assertion failed: [Need value.shape >= size, got ] [1 256 256 3] [32 256 256 3]
[[{{node random_crop/Assert/Assert}}]]
尝试使用 1 的批量大小:
tensor = tf.image.random_crop(value=tensor, size=(1,256, 256, 3))
但我认为您不应该将高级数据加载器与低级tf.data.Dataset
。 尝试仅使用后者。
import tensorflow as tf
image_dir = r'C:\Users\user\Pictures'
files = tf.data.Dataset.list_files(image_dir + '\\*jpg')
def load(filepath):
image = tf.io.read_file(filepath)
image = tf.image.decode_image(image)
return image
ds = files.map(load)
def augment(tensor):
tensor = tf.cast(x=tensor, dtype=tf.float32)
tensor = tf.divide(x=tensor, y=tf.constant(255.))
tensor = tf.image.random_crop(value=tensor, size=(100, 100, 3))
random_target = tf.random.uniform((1,), dtype=tf.int32, maxval=2)
return tensor, random_target
train_set_raw = ds.map(augment).batch(32)
model = tf.keras.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(8, activation='relu'),
tf.keras.layers.Dense(1, activation='sigmoid')
])
model.compile(loss='binary_crossentropy', optimizer='adam')
history = model.fit(train_set_raw)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.