
TypeError: ('Keyword argument not understood:', 'input') for CNN with pretrained model from VGG16

Hi all, I am learning CNN models and I am running this notebook: https://www.kaggle.com/sureshhpba05/final-work

When I run this code to load the VGG16 model:


def VGG16_MODEL(img_rows=IMG_SIZE, img_cols=IMG_SIZE, color_type=3):
    # Remove fully connected layer and replace
    # with softmax for classifying 10 classes
    model_vgg16_1 = VGG16(weights="imagenet", include_top=False)

    # Freeze all layers of the pre-trained model
    for layer in model_vgg16_1.layers:
        layer.trainable = False
        
    x = model_vgg16_1.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(CLASSES, activation = 'softmax')(x)

    model = Model(input = model_vgg16_1.input, output = predictions)
    
    return model

print("Model 1 network...")
model_vgg16_1 = VGG16_MODEL(img_rows=IMG_SIZE, img_cols=IMG_SIZE)

model_vgg16_1.summary()

model_vgg16_1.compile(loss='categorical_crossentropy',
                         optimizer='rmsprop',
                         metrics=['accuracy'])

I get this error:

TypeError: ('Keyword argument not understood:', 'input')

Can anyone tell me why I get this error and how I should fix it?

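The Model constructor in tf.keras takes the keyword arguments inputs and outputs (plural). The singular input and output are not recognized, which is what raises the TypeError. For the benefit of the community, the complete working code is given below.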
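
To illustrate why the original call fails, here is a minimal pure-Python sketch of a constructor that rejects unknown keyword arguments. `ToyModel` and `ALLOWED_KWARGS` are hypothetical names made up for this illustration; this is not Keras's actual source, and the only assumption is that Keras validates keywords in a broadly similar way. Because the exception is raised with two arguments, Python prints it as a tuple, which matches the shape of the message in the question.

```python
# Toy stand-in for a constructor that validates its keyword arguments.
# ToyModel and ALLOWED_KWARGS are hypothetical names, not part of Keras.
ALLOWED_KWARGS = {"inputs", "outputs", "name"}


class ToyModel:
    def __init__(self, **kwargs):
        for key in kwargs:
            if key not in ALLOWED_KWARGS:
                # Two positional arguments make the exception print as a tuple.
                raise TypeError("Keyword argument not understood:", key)
        self.inputs = kwargs.get("inputs")
        self.outputs = kwargs.get("outputs")


def try_build(**kwargs):
    """Return the error text, or "ok" if construction succeeds."""
    try:
        ToyModel(**kwargs)
    except TypeError as err:
        return str(err)
    return "ok"


print(try_build(input="x", output="y"))    # singular keywords are rejected
print(try_build(inputs="x", outputs="y"))  # plural keywords are accepted
```

The corrected script, which uses the plural keywords, follows.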

import tensorflow as tf
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras import Model
from tensorflow.keras.applications import VGG16

IMG_SIZE = 224
CLASSES = 10

def VGG16_MODEL(img_rows=IMG_SIZE, img_cols=IMG_SIZE, color_type=3):
    # Load the VGG16 convolutional base without its fully connected top;
    # a new softmax head for CLASSES classes is attached below.
    model_vgg16_1 = VGG16(weights="imagenet", include_top=False)

    # Freeze all layers of the pre-trained model
    for layer in model_vgg16_1.layers:
        layer.trainable = False

    # Build the new classification head on top of the frozen base
    x = model_vgg16_1.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(CLASSES, activation='softmax')(x)

    # Model expects the plural keywords inputs and outputs
    model = Model(inputs=model_vgg16_1.input, outputs=predictions)

    return model

print("Model 1 network...")
model_vgg16_1 = VGG16_MODEL(img_rows=IMG_SIZE, img_cols=IMG_SIZE)

model_vgg16_1.summary()

model_vgg16_1.compile(loss='categorical_crossentropy',
                         optimizer='rmsprop',
                         metrics=['accuracy'])

Output:

Model 1 network...
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_4 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, None, None, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, None, None, 64)    36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, None, None, 64)    0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, None, None, 128)   73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, None, None, 128)   147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, None, None, 128)   0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, None, None, 256)   295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, None, None, 256)   590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, None, None, 256)   0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, None, None, 512)   1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, None, None, 512)   0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, None, None, 512)   2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, None, None, 512)   0         
_________________________________________________________________
global_average_pooling2d_3 ( (None, 512)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 1024)              525312    
_________________________________________________________________
dense_6 (Dense)              (None, 10)                10250     
=================================================================
Total params: 15,250,250
Trainable params: 15,250,250
Non-trainable params: 0
_________________________________________________________________
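
The summary above lists every parameter as trainable (`Non-trainable params: 0`). That is what Keras reports when the freezing loop has not run over all layers before the model is built, for example when `return model` is indented into the loop body. As a hedged sanity check, the sketch below builds the same architecture and counts trainable and non-trainable parameters directly. Two things are assumptions of this sketch and not part of the original answer: the `count_params` helper is hypothetical, and `weights=None` is used to skip the ImageNet download, since the counts do not depend on the weight values.

```python
import math

from tensorflow.keras import Model
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D


def count_params(weights):
    """Sum the number of scalar entries across a list of weight variables."""
    return sum(math.prod(int(d) for d in w.shape) for w in weights)


# weights=None avoids downloading ImageNet weights (assumption of this sketch).
base = VGG16(weights=None, include_top=False)
for layer in base.layers:
    layer.trainable = False  # the loop finishes before the head is built

x = GlobalAveragePooling2D()(base.output)
x = Dense(1024, activation="relu")(x)
predictions = Dense(10, activation="softmax")(x)
model = Model(inputs=base.input, outputs=predictions)

trainable = count_params(model.trainable_weights)
frozen = count_params(model.non_trainable_weights)
print(f"trainable={trainable:,} non_trainable={frozen:,}")
```

With the whole VGG16 base frozen, only the two Dense layers should remain trainable. By the per-layer figures in the summary that is 525,312 + 10,250 = 535,562 parameters, leaving 15,250,250 - 535,562 = 14,714,688 non-trainable.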

(paraphrased from Frightera)
