
Python Keras model input incompatible with layer


I have 2 versions of a model built with keras. Both seem to build fine, but when I compile and fit them I get the same error. I'm not sure what the problem is.

import tensorflow as tf
from tensorflow.keras import layers, models

def build_cnn():
    model = models.Sequential([
        layers.Conv2D(32, (3, 3), input_shape=(32, 32, 3), padding='same', name='conv1'),
        layers.MaxPooling2D((2, 2),  name='maxpooling1'),
        layers.Conv2D(64, (3, 3), activation='relu', name='conv2'),
        layers.MaxPooling2D((2, 2), name = 'maxpooling2'),
        layers.Conv2D(64, (3, 3), activation='relu', name = 'conv3'),
        layers.Flatten(name = 'flatten'),
        layers.Dense(64, activation='relu', name='dense1'),
        layers.Dense(10, name='dense2')
        ], name='CNN')
    return model

def build_cnn2():
    model = models.Sequential([
        tf.keras.layers.Input(shape=(32,32,3)),
        tf.keras.layers.Conv2D(32, (3,3), padding='same', name='conv2d'),
        tf.keras.layers.MaxPool2D(pool_size=(2,2), name='maxpooling'),
        tf.keras.layers.Flatten(name='flatten'),
        tf.keras.layers.Dense(10, activation='softmax', name='dense'),
        ], name='conv2d')
    return model

def train(model):
    model.compile(optimizer='adam',
                 loss='binary_crossentropy',
                 metrics=['acc'])
    return model.fit(x_train, y_train,
                    epochs=5,
                    validation_split=0.1)

model2 = build_cnn()
log2 = train(model2)

I get the following error with both model2 = build_cnn() and model2 = build_cnn2():

 ValueError: Input 0 of layer CNN is incompatible with the layer: : expected min_ndim=4, found ndim=2. Full shape received: (None, 30)

Judging from the error, your data has not been preprocessed correctly: as the message says, the model expects 4D input, (None, h, w, c), but it received (None, 30).
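
The check Keras is complaining about can be reproduced on shapes alone. Below is a minimal pure-Python sketch; check_batch_shape is a hypothetical helper (not a Keras function) that only approximates the min_ndim=4 requirement of Conv2D, and it assumes the (32, 32, 3) input used by both models above. It also shows why a reshape alone cannot fix this particular case: 30 values per sample cannot fill a 32 x 32 x 3 image.

```python
import math


# Hypothetical helper (not part of Keras): explain why a batch shape cannot be
# fed to a model whose first layer expects (height, width, channels) images.
def check_batch_shape(shape, expected=(32, 32, 3)):
    """Return None if shape is (batch, *expected), else a short reason string."""
    # Conv2D needs at least 4 dims: (batch, height, width, channels)
    if len(shape) < 4:
        per_sample = math.prod(shape[1:])  # values per sample, batch dim excluded
        needed = math.prod(expected)
        reason = f"expected min_ndim=4, found ndim={len(shape)}"
        if per_sample != needed:
            reason += (f"; each sample has {per_sample} values but {expected} "
                       f"needs {needed}, so a reshape cannot fix it")
        return reason
    if tuple(shape[1:]) != tuple(expected):
        return f"expected per-sample shape {expected}, got {tuple(shape[1:])}"
    return None


print(check_batch_shape((None, 30)))         # mirrors the shape in the error
print(check_batch_shape((None, 32, 32, 3)))  # None: shape is acceptable
```

If the first call reports a problem for your real x_train.shape, the fix lies in how x_train is built rather than in the layers.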

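Also, the last layer of one of your models has no activation (it outputs raw logits) and the other uses softmax, yet you set the loss function to binary_crossentropy instead of CategoricalCrossentropy. For a 10-class problem with one-hot labels, use CategoricalCrossentropy after a softmax output, or CategoricalCrossentropy(from_logits=True) if the last layer has no activation.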
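
To make the loss/activation pairing concrete, here is a NumPy sketch, assuming 10 classes and one-hot labels as in the question, of categorical cross-entropy computed two ways: on softmax probabilities (what CategoricalCrossentropy() expects by default) and directly on raw logits (what CategoricalCrossentropy(from_logits=True) expects). The helper names are illustrative only, not Keras internals, and the logits are random stand-ins rather than real model output.

```python
import numpy as np


def one_hot(labels, num_classes):
    # Same encoding as tf.keras.utils.to_categorical for integer labels 0..num_classes-1
    return np.eye(num_classes, dtype=np.float32)[labels]


def softmax(logits):
    # Subtract the row max for numerical stability before exponentiating
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def categorical_crossentropy(y_true, probs, eps=1e-7):
    # Batch mean of -sum(y_true * log(p)); expects probabilities (softmax output)
    probs = np.clip(probs, eps, 1.0)
    return float(np.mean(-np.sum(y_true * np.log(probs), axis=-1)))


def categorical_crossentropy_from_logits(y_true, logits):
    # Same loss computed straight from raw scores via a stable log-softmax
    z = logits - logits.max(axis=-1, keepdims=True)
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return float(np.mean(-np.sum(y_true * log_probs, axis=-1)))


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))              # stand-in for a Dense(10) with no activation
y_true = one_hot(np.array([3, 0, 9, 3]), 10)   # one-hot targets, shape (4, 10)

print(categorical_crossentropy(y_true, softmax(logits)))
print(categorical_crossentropy_from_logits(y_true, logits))  # agrees up to rounding
```

Both lines print essentially the same number, which is why a final Dense(10) without activation works as long as from_logits=True is set. Passing raw logits to the default loss treats them as probabilities, which gives a wrong loss.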

Here is a possible solution that addresses the issues above.

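import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

def build_cnn():
    model = Sequential([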
        layers.Conv2D(32, (3, 3), input_shape=(32, 32, 3), 
                       padding='same', name='conv1'),
        layers.MaxPooling2D((2, 2),  name='maxpooling1'),
        layers.Conv2D(64, (3, 3), activation='relu', name='conv2'),
        layers.MaxPooling2D((2, 2), name = 'maxpooling2'),
        layers.Conv2D(64, (3, 3), activation='relu', name = 'conv3'),
        layers.Flatten(name = 'flatten'),
        layers.Dense(64, activation='relu', name='dense1'),
        layers.Dense(10, activation='softmax', name='dense2')
        ], name='CNN')
    return model

def train(model, x_train, y_train):
    model.compile(optimizer='adam',
                 loss=tf.keras.losses.CategoricalCrossentropy(),
                 metrics=['acc'])
    return model.fit(x_train, y_train,
                    epochs=5, verbose=2,
                    validation_split=0.1)

model2 = build_cnn()
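
The layer stack itself is consistent, which a quick shape trace confirms. The sketch below is a hypothetical pure-Python helper (trace is not a Keras function) that follows each layer's output shape and parameter count under Keras's defaults as assumed here: Conv2D stride 1 with padding='valid' unless 'same' is set, and MaxPooling2D stride equal to its pool size. Under those assumptions build_cnn flattens to 1,600 features and has 159,434 trainable parameters, and the shapes should line up with what model.summary() reports.

```python
import math


def conv_out(size, kernel, padding):
    # 'same' keeps the spatial size at stride 1; 'valid' shrinks it by kernel - 1
    return size if padding == "same" else size - kernel + 1


def trace(spec, input_shape):
    """Hypothetical helper: return per-layer output shapes and total parameter count."""
    shape = tuple(input_shape)
    params = 0
    shapes = []
    for kind, *args in spec:
        if kind == "conv":
            filters, k, padding = args
            h, w, c = shape
            params += k * k * c * filters + filters      # kernel weights + biases
            shape = (conv_out(h, k, padding), conv_out(w, k, padding), filters)
        elif kind == "pool":
            (p,) = args
            h, w, c = shape
            shape = (h // p, w // p, c)                  # 'valid' pooling, stride = pool size
        elif kind == "flatten":
            shape = (math.prod(shape),)
        elif kind == "dense":
            (units,) = args
            params += shape[0] * units + units           # weights + biases
            shape = (units,)
        shapes.append(shape)
    return shapes, params


# Layer specs transcribed from build_cnn and build_cnn2 above
BUILD_CNN = [
    ("conv", 32, 3, "same"), ("pool", 2),
    ("conv", 64, 3, "valid"), ("pool", 2),
    ("conv", 64, 3, "valid"),
    ("flatten",), ("dense", 64), ("dense", 10),
]
BUILD_CNN2 = [("conv", 32, 3, "same"), ("pool", 2), ("flatten",), ("dense", 10)]

for name, spec in [("build_cnn", BUILD_CNN), ("build_cnn2", BUILD_CNN2)]:
    shapes, params = trace(spec, (32, 32, 3))
    print(name, shapes, params)
```

Both models accept (32, 32, 3) images without complaint, so the (None, 30) in the error has to come from the data.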

I use mnist for the demo. Its images are 28 x 28 grayscale, whereas your model takes (32, 32, 3), so mnist needs preprocessing first. Hopefully you can adapt this to your own data.

Dataset

import numpy as np
import tensorflow as tf

(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

print(x_train.shape, y_train.shape)

# expand new axis, channel axis 
x_train = np.expand_dims(x_train, axis=-1)
print(x_train.shape)

# need 3 channel (instead of 1)
x_train = np.repeat(x_train, 3, axis=-1)
print(x_train.shape)

# normalize pixel values from [0, 255] to [0, 1]
x_train = x_train.astype('float32') / 255
print(x_train.shape)

# resize the images from 28 x 28 to 32 x 32 to match the model's input_shape
x_train = tf.image.resize(x_train, [32, 32])
print(x_train.shape)

# one-hot encode the integer labels (targets)
y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)
print(y_train.shape)

Output:
(60000, 28, 28) (60000,)
(60000, 28, 28, 1)
(60000, 28, 28, 3)
(60000, 28, 28, 3)
(60000, 32, 32, 3)
(60000, 10)
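
If you would rather not interpolate the digits, one alternative (an assumption on my part, not what the code above does) is to zero-pad each 28 x 28 image by 2 pixels on every side to reach 32 x 32. The NumPy sketch below uses a hypothetical helper, pad_to_32, and random uint8 arrays as a stand-in for mnist's x_train so it runs without downloading anything.

```python
import numpy as np


def pad_to_32(x):
    """Hypothetical helper: (N, 28, 28) uint8 -> (N, 32, 32, 3) float32 in [0, 1]."""
    x = x.astype("float32") / 255.0            # scale pixel values to [0, 1]
    x = np.pad(x, ((0, 0), (2, 2), (2, 2)))    # zero-pad height and width by 2 on each side
    x = np.expand_dims(x, axis=-1)             # add channel axis -> (N, 32, 32, 1)
    return np.repeat(x, 3, axis=-1)            # copy gray channel 3 times -> (N, 32, 32, 3)


# Stand-in for mnist x_train: random uint8 images with the same per-image shape
fake = np.random.default_rng(0).integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
out = pad_to_32(fake)
print(out.shape, out.dtype)  # (5, 32, 32, 3) float32
```

Padding leaves the original pixels untouched instead of resampling them; whether it or tf.image.resize works better for a given model is something to check empirically.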

Now you can train your model.

log2 = train(model2, x_train, y_train)

Epoch 1/5
4ms/step - loss: 0.2792 - acc: 0.9133 - val_loss: 0.0676 - val_acc: 0.9815
Epoch 2/5
4ms/step - loss: 0.0454 - acc: 0.9864 - val_loss: 0.0400 - val_acc: 0.9883
Epoch 3/5
4ms/step - loss: 0.0336 - acc: 0.9892 - val_loss: 0.0415 - val_acc: 0.9900
Epoch 4/5
4ms/step - loss: 0.0235 - acc: 0.9926 - val_loss: 0.0359 - val_acc: 0.9907
Epoch 5/5
4ms/step - loss: 0.0163 - acc: 0.9948 - val_loss: 0.0295 - val_acc: 0.9918


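Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.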

 
Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM