[英]Input 0 of layer dense is incompatible with the layer: expected axis -1 of input shape to have value 8192 but received input with shape (None, 61608)
I am trying to create an image processing CNN.我正在尝试创建一个图像处理 CNN。 I am using VGG16 to speed up some of the learning process.我正在使用 VGG16 来加快一些学习过程。 The creation of my CNN below works to the point of training and saving the model & weights.下面创建我的 CNN 可以训练和保存 model 和权重。 The issue occurs when I try to run a predict function after loading in the model.当我在加载 model 后尝试运行预测 function 时会出现此问题。
image_gen = ImageDataGenerator()
train = image_gen.flow_from_directory('./data/train', class_mode='categorical', shuffle=False, batch_size=10, target_size=(151, 136))
val = image_gen.flow_from_directory('./data/validate', class_mode='categorical', shuffle=False, batch_size=10, target_size=(151, 136))
pretrained_model = VGG16(include_top=False, input_shape=(151, 136, 3), weights='imagenet')
pretrained_model.summary()
vgg_features_train = pretrained_model.predict(train)
vgg_features_val = pretrained_model.predict(val)
train_target = to_categorical(train.labels)
val_target = to_categorical(val.labels)
model = Sequential()
model.add(Flatten())
model.add(Dense(100, activation='relu'))
model.add(Dropout(0.5))
model.add(BatchNormalization())
model.add(Dense(2, activation='softmax'))
model.compile(optimizer='rmsprop', metrics=['accuracy'], loss='categorical_crossentropy')
target_dir = './models/weights-improvement'
if not os.path.exists(target_dir):
os.mkdir(target_dir)
checkpoint = ModelCheckpoint(filepath=target_dir + 'weights-improvement-{epoch:02d}-{val_accuracy:.2f}.hdf5', monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
callbacks_list = [checkpoint]
model.fit(vgg_features_train, train_target, epochs=100, batch_size=8, validation_data=(vgg_features_val, val_target), callbacks=callbacks_list)
model.save('./models/model')
model.save_weights('./models/weights')
I have this predict function, that I would like to load in an image, and then return the categorisation of this image that the model gives.我有这个预测 function,我想加载到一个图像中,然后返回 model 给出的这个图像的分类。
from keras.preprocessing.image import load_img, img_to_array
def predict(file):
x = load_img(file, target_size=(151,136,3))
x = img_to_array(x)
print(x.shape)
print(x.shape)
x = np.expand_dims(x, axis=0)
array = model.predict(x)
result = array[0]
if result[0] > result[1]:
if result[0] > 0.9:
print("Predicted answer: Buy")
answer = 'buy'
print(result)
print(array)
else:
print("Predicted answer: Not confident")
answer = 'n/a'
print(result)
else:
if result[1] > 0.9:
print("Predicted answer: Sell")
answer = 'sell'
print(result)
else:
print("Predicted answer: Not confident")
answer = 'n/a'
print(result)
return answer
The issue I am experiencing is when I run this predict function, I get the following error.我遇到的问题是当我运行这个预测 function 时,我收到以下错误。
File "predict-binary.py", line 24, in predict
array = model.predict(x)
File ".venv\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1629, in predict
tmp_batch_outputs = self.predict_function(iterator)
File ".venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 828, in __call__
result = self._call(*args, **kwds)
File ".venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 871, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File ".venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 725, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File ".venv\lib\site-packages\tensorflow\python\eager\function.py", line 2969, in _get_concrete_function_internal_garbage_collected
graph_function, _ = self._maybe_define_function(args, kwargs)
File ".venv\lib\site-packages\tensorflow\python\eager\function.py", line 3361, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File ".venv\lib\site-packages\tensorflow\python\eager\function.py", line 3196, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File ".venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 990, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File ".venv\lib\site-packages\tensorflow\python\eager\def_function.py", line 634, in wrapped_fn
out = weak_wrapped_fn().__wrapped__(*args, **kwds)
File ".venv\lib\site-packages\tensorflow\python\framework\func_graph.py", line 977, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
.venv\lib\site-packages\tensorflow\python\keras\engine\training.py:1478 predict_function *
return step_function(self, iterator)
.venv\lib\site-packages\tensorflow\python\keras\engine\training.py:1468 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
.venv\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:1259 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
.venv\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:2730 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
.venv\lib\site-packages\tensorflow\python\distribute\distribute_lib.py:3417 _call_for_each_replica
return fn(*args, **kwargs)
.venv\lib\site-packages\tensorflow\python\keras\engine\training.py:1461 run_step **
outputs = model.predict_step(data)
.venv\lib\site-packages\tensorflow\python\keras\engine\training.py:1434 predict_step
return self(x, training=False)
.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:1012 __call__
outputs = call_fn(inputs, *args, **kwargs)
.venv\lib\site-packages\tensorflow\python\keras\engine\sequential.py:375 call
return super(Sequential, self).call(inputs, training=training, mask=mask)
.venv\lib\site-packages\tensorflow\python\keras\engine\functional.py:424 call
return self._run_internal_graph(
.venv\lib\site-packages\tensorflow\python\keras\engine\functional.py:560 _run_internal_graph
outputs = node.layer(*args, **kwargs)
.venv\lib\site-packages\tensorflow\python\keras\engine\base_layer.py:998 __call__
input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
.venv\lib\site-packages\tensorflow\python\keras\engine\input_spec.py:255 assert_input_compatibility
raise ValueError(
ValueError: Input 0 of layer dense is incompatible with the layer: expected axis -1 of input shape to have value 8192 but received input with shape (None, 61608)
I'm assuming I need to change something between the Flatten()
and Dense()
layers of my model, but I'm not sure what.我假设我需要在 model 的Flatten()
和Dense()
层之间进行一些更改,但我不确定是什么。 I attempted to add model.add(Dense(61608, activation='relu))
between these two as that seemed to be what was suggested in another post I saw (cannot find link now), but it lead to the same error.我试图在这两者之间添加model.add(Dense(61608, activation='relu))
因为这似乎是我在另一篇文章中看到的建议(现在找不到链接),但它导致了同样的错误。 (I tried it with 8192 instead of 61608 as well). (我也尝试使用 8192 而不是 61608)。 Any help is appreciated, thanks.任何帮助表示赞赏,谢谢。
EDIT #1:编辑#1:
Changing the model creation/training code as I think it was suggested by Gerry P to this更改 model 创建/培训代码,因为我认为Gerry P对此提出了建议
img_shape = (151,136,3)
base_model=VGG19( include_top=False, input_shape=img_shape, pooling='max', weights='imagenet' )
x=base_model.output
x=Dense(100, activation='relu')(x)
x=Dropout(0.5)(x)
x=BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001)(x)
output=Dense(2, activation='softmax')(x)
model=Model(inputs=base_model.input, outputs=output)
image_gen = ImageDataGenerator()
train = image_gen.flow_from_directory('./data/train', class_mode='categorical', shuffle=False, batch_size=10, target_size=(151, 136))
val = image_gen.flow_from_directory('./data/validate', class_mode='categorical', shuffle=False, batch_size=10, target_size=(151, 136))
vgg_features_train = base_model.predict(train)
vgg_features_val = base_model.predict(val)
train_target = to_categorical(train.labels)
val_target = to_categorical(val.labels)
model.compile(optimizer='rmsprop', metrics=['accuracy'], loss='categorical_crossentropy')
model.fit(vgg_features_train, train_target, epochs=100, batch_size=8, validation_data=(vgg_features_val, val_target), callbacks=callbacks_list)
This resulted in a different input shape error of File "train-binary.py", line 37, in <module> model.fit(vgg_features_train, train_target, epochs=100, batch_size=8, validation_data=(vgg_features_val, val_target), callbacks=callbacks_list) ValueError: Input 0 is incompatible with layer model: expected shape=(None, 151, 136, 3), found shape=(None, 512)
这导致File "train-binary.py", line 37, in <module> model.fit(vgg_features_train, train_target, epochs=100, batch_size=8, validation_data=(vgg_features_val, val_target), callbacks=callbacks_list) ValueError: Input 0 is incompatible with layer model: expected shape=(None, 151, 136, 3), found shape=(None, 512)
的不同输入形状错误File "train-binary.py", line 37, in <module> model.fit(vgg_features_train, train_target, epochs=100, batch_size=8, validation_data=(vgg_features_val, val_target), callbacks=callbacks_list) ValueError: Input 0 is incompatible with layer model: expected shape=(None, 151, 136, 3), found shape=(None, 512)
your model is expecting to see an input for model.predict that has the same dimensions as it was trained on.您的 model 期望看到 model.predict 的输入,其尺寸与训练时的尺寸相同。 In this case it is the dimensions of vgg_features_train.The input to model.predict that you are generating is for the input to the VGG model.在这种情况下,它是 vgg_features_train 的尺寸。您生成的 model.predict 的输入是 VGG model 的输入。 You are essentially trying to do transfer learning so I suggest you proceed as below您实际上是在尝试进行迁移学习,因此我建议您按照以下步骤进行
base_model=tf.keras.applications.VGG19( include_top=False, input_shape=img_shape, pooling='max', weights='imagenet' )
x=base_model.output
x=Dense(100, activation='relu'))(x)
x=Dropout(0.5)(x)
x=BatchNormalization(axis=-1, momentum=0.99, epsilon=0.001)(x)
output=Dense(2, activation='softmax')(x)
model=Model(inputs=base_model.input, outputs=output)
model.fit( train, epochs=100, batch_size=8, validation_data=val, callbacks=callbacks_list)
now for prediction you can use the same dimensions as you used to train the model.现在进行预测,您可以使用与训练 model 相同的尺寸。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.