[英]Python train convolutional neural network on csv numpy error input shape
[英]Error in Convolutional Neural network for input shape
我有1000張28 * 28分辨率的圖像。 我將這1000張圖像轉換為numpy數組,並形成了一個大小為(1000,28,28)的新數組。 因此,在使用keras創建卷積層時,將輸入shape(X值)指定為(1000,28,28),將輸出shape(Y值)指定為(1000,10)。 因為我有1000個示例是輸入和10個類別的輸出。
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),activation='relu',kernel_initializer='he_normal',input_shape=(1000,28,28)))
.
.
.
model.fit(train_x,train_y,batch_size=32,epochs=10,verbose=1)
因此,在使用fit
函數時,它顯示ValueError: Error when checking input: expected conv2d_1_input to have 4 dimensions, but got array with shape (1000, 28, 28)
為錯誤。 請幫助我為CNN提供適當的輸入和輸出尺寸。
碼:
model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),activation='relu',kernel_initializer='he_normal',input_shape=(4132,28,28)))
model.add(MaxPooling2D((2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(Dropout(0.4))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.3))
model.add(Dense(10, activation='softmax'))
model.compile(loss=keras.losses.categorical_crossentropy,optimizer=keras.optimizers.Adam(),metrics=['accuracy'])
model.summary()
train_x = numpy.array([train_x])
model.fit(train_x,train_y,batch_size=32,epochs=10,verbose=1)
您需要將channel
設置為1
:( input_shape
)的輸入更改為4維,並且需要將卷積層的input_shape
更改為( input_shape
(28, 28, 1)
:
model.add(Conv2D(32, kernel_size=(3, 3),...,input_shape=(28,28,1)))
您的numpy數組需要第四個維度,通用標准是使用第一個維度對樣本進行編號,因此將(1000,28,28)更改為(1,1000,28,28)。
您可以在此處了解更多信息。
從您的輸入中看起來您正在使用tensorflow作為后端。
在keras中, input_shape
應該始終為3維。 對於tensorflow作為后端,模型的input_shape
將是
input_shape = [img_height,img_width,channels(depth)]
在你的情況下張量流后端應該是
input_shape = [28,28,1]
並且train_x
的形狀應為
train_x = [batch_size,img_height,img_width,channels(depth)]
在你的情況下
train_x = [1000,28,28,1]
當您使用灰度圖像時,圖像的尺寸將為(image_height,image_width),因此您必須為圖像添加額外的尺寸,從而導致(image_height,image_width,1)'1'表示圖像的深度,對於灰度值為“ 1”,對於rgb為“ 3”。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.