簡體   English   中英

`.fit()` 的輸入應該有 4 級

[英]Input to `.fit()` should have rank 4

我使用 keras 學習數據增強。 鏈接: https : //keras.io/preprocessing/image/該鏈接使用 CNN,但是當我嘗試使用下面給出的密集層時,我收到錯誤消息。 錯誤位於 datagen.fit()。 描述說輸入應該是等級4。如何解決?

#import dataset
(X_train, y_train), (X_test, y_test) = cifar10.load_data()

#change shape from image to vector
X_train = X_train.reshape(50000, 32 * 32 * 3)
X_test = X_test.reshape(10000, 32 * 32 * 3)

#preprocess
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255.0
X_test /= 255.0

#change labels from numeric to one hot encoded
Y_train = to_categorical(y_train, 10)
Y_test =  to_categorical(y_test, 10)

model = Sequential()
model.add(Dense(1024, input_shape=(3072, )))
model.add(Activation('relu'))
model.add(Dense(10))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])


datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)


datagen.fit(X_train)

# fits the model on batches with real-time data augmentation:
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=32),
                    steps_per_epoch=len(X_train) / 32, epochs=epochs,verbose=1,
                        validation_data=datagen.flow(X_test, Y_test, batch_size=32))

錯誤

---------------------------------------------------------------------------
ValueError                         
---> 57 datagen.fit(X_train)


ValueError: Input to `.fit()` should have rank 4. Got array with shape: (50000, 3072)

你重塑了你的陣列。 ImageDataGenerator 需要 4 級輸入矩陣(圖像索引、高度、寬度、深度)。 您的 Reshaping 給出了 2 級輸入矩陣。 因此錯誤。 解決方法是刪除重塑,然后在第一個密集層上方添加一個 CNN 層(只是建議)。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM