
Why does my CNN input size cause a ValueError?

I am trying to train a UNET-like model on grayscale images in Google Colab, using the following architecture:

from keras.models import Model
from keras.layers import Input, UpSampling2D
from keras.layers import Dropout
from keras.layers import Conv2D, Conv2DTranspose
from keras.layers import AveragePooling2D
from keras.layers import concatenate

inputs = Input((2048, 2048, 1))

c1 = Conv2D(16, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (inputs)
c1 = Dropout(0.1) (c1)
c1 = Conv2D(16, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c1)
p1 = AveragePooling2D((2, 2)) (c1)

c2 = Conv2D(32, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (p1)
c2 = Dropout(0.1) (c2)
c2 = Conv2D(32, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c2)
p2 = AveragePooling2D((2, 2)) (c2)

c3 = Conv2D(64, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (p2)
c3 = Dropout(0.2) (c3)
c3 = Conv2D(64, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c3)
p3 = AveragePooling2D((2, 2)) (c3)

c4 = Conv2D(128, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (p3)
c4 = Dropout(0.2) (c4)
c4 = Conv2D(128, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c4)

u5 = Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same') (c4)
u5 = concatenate([u5, c3])
c5 = Conv2D(64, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (u5)
c5 = Dropout(0.2) (c5)
c5 = Conv2D(64, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c5)

u6 = Conv2DTranspose(32, (2, 2), strides=(2, 2), padding='same') (c5)
u6 = concatenate([u6, c2])
c6 = Conv2D(32, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (u6)
c6 = Dropout(0.1) (c6)
c6 = Conv2D(32, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c6)

u7 = Conv2DTranspose(16, (2, 2), strides=(2, 2), padding='same') (c6)
u7 = concatenate([u7, c1], axis=3)
c7 = Conv2D(16, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (u7)
c7 = Dropout(0.1) (c7)
c7 = Conv2D(16, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same') (c7)

outputs = Conv2D(1, (1, 1), activation='sigmoid') (c7)

model = Model(inputs=[inputs], outputs=[outputs])
# compile the model with RMSProp as optimizer, MSE as loss function and MAE as metric
model.compile(optimizer='rmsprop', loss='mean_squared_error', metrics=['mean_absolute_error'])
model.summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_7 (InputLayer)            (None, 2048, 2048, 1 0                                            
__________________________________________________________________________________________________
conv2d_34 (Conv2D)              (None, 2048, 2048, 1 160         input_7[0][0]                    
__________________________________________________________________________________________________
dropout_16 (Dropout)            (None, 2048, 2048, 1 0           conv2d_34[0][0]                  
__________________________________________________________________________________________________
conv2d_35 (Conv2D)              (None, 2048, 2048, 1 2320        dropout_16[0][0]                 
__________________________________________________________________________________________________
average_pooling2d_8 (AveragePoo (None, 1024, 1024, 1 0           conv2d_35[0][0]                  
__________________________________________________________________________________________________
conv2d_36 (Conv2D)              (None, 1024, 1024, 3 4640        average_pooling2d_8[0][0]        
__________________________________________________________________________________________________
dropout_17 (Dropout)            (None, 1024, 1024, 3 0           conv2d_36[0][0]                  
__________________________________________________________________________________________________
conv2d_37 (Conv2D)              (None, 1024, 1024, 3 9248        dropout_17[0][0]                 
__________________________________________________________________________________________________
average_pooling2d_9 (AveragePoo (None, 512, 512, 32) 0           conv2d_37[0][0]                  
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 512, 512, 64) 18496       average_pooling2d_9[0][0]        
__________________________________________________________________________________________________
dropout_18 (Dropout)            (None, 512, 512, 64) 0           conv2d_38[0][0]                  
__________________________________________________________________________________________________
conv2d_39 (Conv2D)              (None, 512, 512, 64) 36928       dropout_18[0][0]                 
__________________________________________________________________________________________________
average_pooling2d_10 (AveragePo (None, 256, 256, 64) 0           conv2d_39[0][0]                  
__________________________________________________________________________________________________
conv2d_40 (Conv2D)              (None, 256, 256, 128 73856       average_pooling2d_10[0][0]       
__________________________________________________________________________________________________
dropout_19 (Dropout)            (None, 256, 256, 128 0           conv2d_40[0][0]                  
__________________________________________________________________________________________________
conv2d_41 (Conv2D)              (None, 256, 256, 128 147584      dropout_19[0][0]                 
__________________________________________________________________________________________________
conv2d_transpose_7 (Conv2DTrans (None, 512, 512, 64) 32832       conv2d_41[0][0]                  
__________________________________________________________________________________________________
concatenate_7 (Concatenate)     (None, 512, 512, 128 0           conv2d_transpose_7[0][0]         
                                                                 conv2d_39[0][0]                  
__________________________________________________________________________________________________
conv2d_42 (Conv2D)              (None, 512, 512, 64) 73792       concatenate_7[0][0]              
__________________________________________________________________________________________________
dropout_20 (Dropout)            (None, 512, 512, 64) 0           conv2d_42[0][0]                  
__________________________________________________________________________________________________
conv2d_43 (Conv2D)              (None, 512, 512, 64) 36928       dropout_20[0][0]                 
__________________________________________________________________________________________________
conv2d_transpose_8 (Conv2DTrans (None, 1024, 1024, 3 8224        conv2d_43[0][0]                  
__________________________________________________________________________________________________
concatenate_8 (Concatenate)     (None, 1024, 1024, 6 0           conv2d_transpose_8[0][0]         
                                                                 conv2d_37[0][0]                  
__________________________________________________________________________________________________
conv2d_44 (Conv2D)              (None, 1024, 1024, 3 18464       concatenate_8[0][0]              
__________________________________________________________________________________________________
dropout_21 (Dropout)            (None, 1024, 1024, 3 0           conv2d_44[0][0]                  
__________________________________________________________________________________________________
conv2d_45 (Conv2D)              (None, 1024, 1024, 3 9248        dropout_21[0][0]                 
__________________________________________________________________________________________________
conv2d_transpose_9 (Conv2DTrans (None, 2048, 2048, 1 2064        conv2d_45[0][0]                  
__________________________________________________________________________________________________
concatenate_9 (Concatenate)     (None, 2048, 2048, 3 0           conv2d_transpose_9[0][0]         
                                                                 conv2d_35[0][0]                  
__________________________________________________________________________________________________
conv2d_46 (Conv2D)              (None, 2048, 2048, 1 4624        concatenate_9[0][0]              
__________________________________________________________________________________________________
dropout_22 (Dropout)            (None, 2048, 2048, 1 0           conv2d_46[0][0]                  
__________________________________________________________________________________________________
conv2d_47 (Conv2D)              (None, 2048, 2048, 1 2320        dropout_22[0][0]                 
__________________________________________________________________________________________________
conv2d_48 (Conv2D)              (None, 2048, 2048, 1 17          conv2d_47[0][0]                  
==================================================================================================
Total params: 481,745
Trainable params: 481,745
Non-trainable params: 0

I was getting an error with tensorflow version 2 ("No algorithm worked"), so I switched to version 1 based on other answers, but now I get an input error:

/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name + '` call to the ' +
     90                               'Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1730             use_multiprocessing=use_multiprocessing,
   1731             shuffle=shuffle,
-> 1732             initial_epoch=initial_epoch)
   1733 
   1734     @interfaces.legacy_generator_methods_support

/usr/local/lib/python3.6/dist-packages/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    218                                             sample_weight=sample_weight,
    219                                             class_weight=class_weight,
--> 220                                             reset_metrics=False)
    221 
    222                 outs = to_list(outs)

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics)
   1506             x, y,
   1507             sample_weight=sample_weight,
-> 1508             class_weight=class_weight)
   1509         if self._uses_dynamic_learning_phase():
   1510             ins = x + y + sample_weights + [1]

/usr/local/lib/python3.6/dist-packages/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, check_array_lengths, batch_size)
    577             feed_input_shapes,
    578             check_batch_axis=False,  # Don't enforce the batch size.
--> 579             exception_prefix='input')
    580 
    581         if y is not None:

/usr/local/lib/python3.6/dist-packages/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    143                             ': expected ' + names[i] + ' to have shape ' +
    144                             str(shape) + ' but got array with shape ' +
--> 145                             str(data_shape))
    146     return data
    147 

ValueError: Error when checking input: expected input_7 to have shape (2048, 2048, 1) but got array with shape (256, 256, 3)

I suspect the ValueError comes up because input shapes are usually 256x256, but I assumed I could simply change the input shape. I expect I will need to add another convolutional layer or two to get the results I want, but right now I just want the CNN to start processing my data. I am loading the files with flow_from_directory, so I know I can't just change the shape of the training data. How do I fix the error?

Do I need to change anything in the network to compensate for the larger input?
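For context, flow_from_directory resizes every image to target_size=(256, 256) and loads it as 3-channel 'rgb' by default, which matches the array shape reported in the error. A minimal sketch of asking the generator for full-resolution grayscale images instead (the directory path and the other arguments here are placeholders for illustration, not the asker's actual setup):

from keras.preprocessing.image import ImageDataGenerator

# Placeholder path and settings, shown only to illustrate the relevant arguments.
datagen = ImageDataGenerator(rescale=1.0 / 255)

# flow_from_directory defaults to target_size=(256, 256) and color_mode='rgb';
# requesting (2048, 2048) grayscale images would match Input((2048, 2048, 1)).
train_gen = datagen.flow_from_directory(
    'data/train',
    target_size=(2048, 2048),
    color_mode='grayscale',
    class_mode=None,   # yield only images, no labels
    batch_size=1,
)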

I think your data has dimensions (256, 256, 3), so you have to adjust your network structure.

Try changing the input layer like this: inputs = Input((256, 256, 3)). This may lead to other changes, for example to the pool sizes.

Your output layer must also have 3 channels: outputs = Conv2D(3, (1, 1), activation='sigmoid')
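A minimal sketch of the two lines that change under this suggestion (the layers c1 through c7 in between stay as in the original code; with 256x256 inputs the three 2x2 poolings still divide evenly, 256 -> 128 -> 64 -> 32, so the concatenate skip connections continue to line up):

# Match the (256, 256, 3) arrays produced by flow_from_directory's defaults.
inputs = Input((256, 256, 3))

# Predict 3 channels so the output matches 3-channel target images.
outputs = Conv2D(3, (1, 1), activation='sigmoid') (c7)

model = Model(inputs=[inputs], outputs=[outputs])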
