
Convolutional Neural Network input_shape dimension error (Keras, Python)

I have a training dataset with the following shape: (300, 5, 720)

    [[[   6.   11.  389. ...,    0.    0.    0.]
      [   2.    0.    0. ...,   62.    0.    0.]
      [   0.    0.   18. ...,    0.    0.    0.]
      [  38.  201.   47. ...,    0.  108.    0.]
      [   0.    0.    1. ...,    0.    0.    0.]]

     [[ 136.   95.    0. ...,    0.    0.    0.]
      [  85.   88.   85. ...,    0.   31.    0.]
      [   0.    0.    0. ...,    0.    0.    0.]
      [   0.    0.    0. ...,    0.    0.    0.]
      [  13.   19.    0. ...,    0.    0.    0.]]]

I am trying to pass each sample as input to the CNN model; each input has size (5, 720). I am using the following model in Keras:

    from keras.models import Sequential
    from keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout, Dense

    cnn = Sequential()

    cnn.add(Conv2D(64, (5, 50),
                   padding="same",
                   activation="relu",
                   data_format="channels_last",
                   input_shape=in_shape))

    cnn.add(MaxPooling2D(pool_size=(2, 2), data_format="channels_last"))

    cnn.add(Flatten())
    cnn.add(Dropout(0.5))

    cnn.add(Dense(number_of_classes, activation="softmax"))

    cnn.compile(loss="categorical_crossentropy", optimizer="adam",
                metrics=["accuracy"])

    cnn.fit(x_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(x_test, y_test),
            shuffle=True)
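To see what each layer does to a single sample, here is a rough sketch of the shape arithmetic in plain Python. It assumes in_shape = (5, 720, 1) (what the rows/cols computation below produces) and the Keras defaults for these layers: stride 1 for Conv2D, and strides equal to pool_size with padding="valid" for MaxPooling2D. The helper functions are hypothetical and only mimic the shape rules; they are not Keras APIs.

```python
# Hypothetical helpers that mimic per-sample shape rules (batch axis omitted).
# Assumes Keras defaults: Conv2D stride 1; MaxPooling2D strides = pool_size, padding="valid".

def conv2d_same(shape, filters):
    # padding="same" with stride 1 keeps height and width; channels become `filters`
    h, w, _ = shape
    return (h, w, filters)

def maxpool2d_valid(shape, pool):
    # "valid" pooling with stride == pool size gives floor(dim / pool) (for dim >= pool)
    h, w, c = shape
    ph, pw = pool
    return (h // ph, w // pw, c)

def flatten(shape):
    # Flatten multiplies all per-sample dimensions together
    n = 1
    for d in shape:
        n *= d
    return (n,)

in_shape = (5, 720, 1)
after_conv = conv2d_same(in_shape, 64)            # kernel (5, 50), "same" padding
after_pool = maxpool2d_valid(after_conv, (2, 2))  # pool_size=(2, 2)
after_flat = flatten(after_pool)                  # Dropout does not change the shape
print(after_conv, after_pool, after_flat)
```

Under these assumptions the shapes are (5, 720, 64), then (2, 360, 64), then (46080,), so the final Dense layer receives 46080 features per sample.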

I compute the input shape as:

    rows, cols = x_train.shape[1:]
    in_shape = (rows, cols, 1)

but I am getting the following error:

ValueError: Error when checking model input: expected conv2d_1_input to have 4 dimensions, but got array with shape (300, 5, 720)

How can I fix this error?
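The "4 dimensions" in the message can be reproduced with NumPy alone. This sketch uses a zero array as a stand-in for x_train (only its shape matters) and relies on the fact that Keras' input_shape describes one sample without the batch axis; expected_ndim is an illustrative name, not a Keras attribute.

```python
import numpy as np

# Stand-in for the question's data: 300 samples of 5 x 720 (values are irrelevant here).
x_train = np.zeros((300, 5, 720))
in_shape = (5, 720, 1)  # the input_shape passed to Conv2D in the question

# input_shape excludes the batch axis, so the model expects one extra dimension.
expected_ndim = len(in_shape) + 1

print(expected_ndim, x_train.ndim)  # 4 3
```

The model expects 4D input, while x_train only has 3 dimensions because the channel axis is missing.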

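This is one of the most common errors with convolutional layers in Keras. With data_format="channels_last", Conv2D expects each sample to have shape (height, width, channels), even when there is only one channel. Since input_shape excludes the batch axis, in_shape = (5, 720, 1) makes the model expect 4D arrays of shape (batch, 5, 720, 1), but x_train has shape (300, 5, 720), which has only 3 dimensions. So reshape the data to add a channel axis of size 1: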

    x_train = x_train.reshape((x_train.shape[0], 5, 720, 1))
    x_test = x_test.reshape((x_test.shape[0], 5, 720, 1))

should solve your problem.
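A minimal runnable sketch of the reshape, using zero arrays as stand-ins for the real data. The test-set size of 60 is an assumed value, since the question does not give it. x_test needs the same reshape because validation_data goes through the same input check as the training data, and np.expand_dims(..., axis=-1) is an equivalent way to append the channel axis.

```python
import numpy as np

# Stand-ins with the question's per-sample shape; 60 test samples is an assumed size.
x_train = np.zeros((300, 5, 720))
x_test = np.zeros((60, 5, 720))

# Append a single channel axis (channels_last) to both arrays.
x_train = x_train.reshape((x_train.shape[0], 5, 720, 1))
x_test = x_test.reshape((x_test.shape[0], 5, 720, 1))

# Equivalent alternative: expand_dims inserts a length-1 axis at the end.
alt = np.expand_dims(np.zeros((300, 5, 720)), axis=-1)

print(x_train.shape, x_test.shape, alt.shape)
```

Both approaches produce arrays whose per-sample shape (5, 720, 1) matches in_shape, so the model's 4D input check passes.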
