![](/img/trans.png)
[英]Keras: flow_from_directory() or flow() using filenames instead of directories
[英]using Keras' flow_from_directory with FCNN
我成功地使用Keras训练了构造神经网络进行图像分割。 现在,我尝试通过对图像进行一些数据增强来提高性能。 为此,我使用ImageDataGenerator
,然后使用flow_from_directory
仅将批处理加载到内存中(我尝试了不带但出现内存错误)。 代码示例为:
training_images = np.array(training_images)
training_masks = np.array(training_masks)[:, :, :, 0].reshape(len(training_masks), 400, 400, 1)
# generators for data augmentation -------
seed = 1
generator_x = ImageDataGenerator(
featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=180,
horizontal_flip=True,
fill_mode='reflect')
generator_y = ImageDataGenerator(
featurewise_center=False,
featurewise_std_normalization=False,
rotation_range=180,
horizontal_flip=True,
fill_mode='reflect')
generator_x.fit(training_images, augment=True, seed=seed)
generator_y.fit(training_masks, augment=True, seed=seed)
image_generator = generator_x.flow_from_directory(
'data',
target_size=(400, 400),
class_mode=None,
seed=seed)
mask_generator = generator_y.flow_from_directory(
'masks',
target_size=(400, 400),
class_mode=None,
seed=seed)
train_generator = zip(image_generator, mask_generator)
model = unet(img_rows, img_cols)
model.fit_generator(train_generator, steps_per_epoch=int(len(training_images)/4), epochs=1)
但是,当我运行代码时,出现以下错误(我正在使用Tensorflow后端):
InvalidArgumentError (see above for traceback): Incompatible shapes: [14400000] vs. [4800000]
[[Node: loss/out_loss/mul = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"](loss/out_loss/Reshape, loss/out_loss/Reshape_1)]]
在错误中,它抱怨不兼容的形状为14400000(400x400x9)与4800000(400x400x3)。 我在这里使用自定义损失函数(如果您查看错误,它表示损失的某些方面)就是Dice系数,定义如下:
y_true_f = K.flatten(y_true)
y_pred_f = K.flatten(y_pred)
intersection = K.sum(y_true_f * y_pred_f)
return (2. * intersection + 1.) / (K.sum(y_true_f) + K.sum(y_pred_f) + 1.)
在这里,我将(400,400,3)带有遮罩的图像用于1类形状(400,400,1)。 我的NN的输入定义为Input((img_rows, img_cols, 3))
并输出为Conv2D(1, (1, 1), activation='sigmoid', name='out')(conv9)
(但这工作正常在没有数据扩充的情况下进行训练时)。
发生错误是因为您正在以RGB颜色模式读取蒙版。
默认color_mode
在flow_from_directory
为'rgb'
。 因此,无需指定color_mode
,您的遮罩将被加载到(batch_size, 400, 400, 3)
color_mode
(batch_size, 400, 400, 3)
数组中。 这就是为什么y_true_f
大于3倍y_pred_f
在您的错误信息。
要读取灰度蒙版,请使用color_mode='grayscale'
:
mask_generator = generator_y.flow_from_directory(
'masks',
target_size=(400, 400),
class_mode=None,
color_mode='grayscale',
seed=seed)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.