
TensorFlow Keras Conv1D input_shape problem: can anyone help me?

I'm trying to rewrite a Keras model that is used to classify satellite images. The model is a plain NN model and I want to rewrite it as a CNN. I found the model here.
The previous NN model is:

from tensorflow import keras

nBands = 6  # number of spectral bands in the image

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(1, nBands)),
    keras.layers.Dense(14, activation='relu'),
    keras.layers.Dense(2, activation='softmax')])

The original image shape is (6, 2054, 2044). After reshaping to a two-dimensional array, it became (2519025, 6). According to the article, the reason for the reshape is:

We will now change the shape of the arrays to a two-dimensional array, which is expected by the majority of ML algorithms, where each row represents a pixel.

Then it was reshaped again, to (2519025, 1, 6).
I use Conv1D as the convolutional layer, like this:
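The two reshapes described above can be sketched with NumPy on a small stand-in array. This is a sketch under assumptions: the image is stored band-first as (bands, rows, cols), the sizes are toy values rather than 2054 × 2044, and np.moveaxis is used so that each row holds one pixel's band values (the article's own reshape helper is not shown in the question).

```python
import numpy as np

# Toy stand-in for a band-first satellite image: (bands, rows, cols).
n_bands, n_rows, n_cols = 6, 4, 5
img = np.arange(n_bands * n_rows * n_cols, dtype=np.float32).reshape(n_bands, n_rows, n_cols)

# Move the band axis last, then flatten rows/cols so each row is one pixel.
# A direct img.reshape(-1, n_bands) would mix values from different pixels.
pixels_2d = np.moveaxis(img, 0, -1).reshape(-1, n_bands)  # (rows*cols, bands)

# Insert a length-1 axis, giving (pixels, 1, bands) as in the question.
pixels_3d = pixels_2d.reshape(-1, 1, n_bands)

print(pixels_2d.shape)  # (20, 6)
print(pixels_3d.shape)  # (20, 1, 6)
```

Row 0 of pixels_2d is img[:, 0, 0], i.e. all six band values of the top-left pixel.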

from tensorflow import keras

model = keras.Sequential([
    keras.layers.Conv1D(filters=64, kernel_size=3, activation='relu', padding='same', input_shape=(2519025, 6)),
    keras.layers.Dense(128, activation='relu', kernel_initializer='glorot_normal'),
    keras.layers.Dense(2, activation='softmax')])

I train the model like this: model.fit(xTrain, yTrain, epochs=2, batch_size=10)

The shapes of xTrain and yTrain are (2519025, 1, 6) and (1679351, 1, 6), respectively.
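The row counts can be cross-checked against the image size. The pixel count is plain arithmetic; the split is an assumption (the question does not say how xTrain was produced), but a 60/40 split reproduces both numbers exactly, which would suggest that the (1679351, 1, 6) array is a held-out feature array rather than the labels.

```python
import math

rows, cols = 2054, 2044   # image size from the question
n_pixels = rows * cols    # one row per pixel after the 2-D reshape

# Assumption: a 60/40 train/test split, with the test count rounded up
# (as scikit-learn's train_test_split does for a float test_size).
test_size = 0.4
n_test = math.ceil(test_size * n_pixels)
n_train = n_pixels - n_test

print(n_pixels, n_train, n_test)  # 4198376 2519025 1679351
```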

I got this warning:

WARNING:tensorflow:Model was constructed with shape (None, 2519025, 6) for input Tensor("conv1d_input:0", shape=(None, 2519025, 6), dtype=float32), but it was called on an input with incompatible shape (None, 1, 6).
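The warning compares the per-sample part of the data shape with the declared input_shape; the leading None is the batch dimension. A simplified pure-Python sketch of that comparison (fits_input_shape is a hypothetical helper, not a Keras function, and Keras's real checks are more involved):

```python
def fits_input_shape(batch_shape, input_shape):
    """Simplified check: does an array of batch_shape fit a declared input_shape?

    Keras prepends the batch dimension (shown as None in the warning), so only
    the remaining dimensions are compared; None in input_shape matches any size.
    """
    sample_shape = tuple(batch_shape[1:])  # drop the batch dimension
    if len(sample_shape) != len(input_shape):
        return False
    return all(want is None or want == got for want, got in zip(input_shape, sample_shape))


print(fits_input_shape((2519025, 1, 6), (2519025, 6)))  # False: 1 step vs 2519025 steps
print(fits_input_shape((2519025, 1, 6), (1, 6)))        # True
print(fits_input_shape((2519025, 1, 6), (None, 6)))     # True: any number of steps
```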

What is the correct input_shape for the model, or how can I change that NN model into a CNN?

As the warning says, the network expects input of shape (None, 2519025, 6), where None is the batch size, but your xTrain and yTrain have shapes (2519025, 1, 6) and (1679351, 1, 6). You can try the following to make your input shape match the network's input shape:

xTrain = xTrain.reshape(2519025, 6)

However, if (2519025, 6) is the size of a single input sample, then your xTrain must have shape (#samples, 2519025, 6). Also, both networks are classifiers with two classes, but you mentioned that your yTrain is (1679351, 1, 6); it has to be (#samples, 2). You will get a separate error for that after fixing the input issue.
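One way to address both points, sketched under assumptions that are not from the answer above: treat each pixel as one sample, lay its 6 band values out as a length-6 sequence with 1 channel so that Conv1D with kernel_size=3 slides across neighbouring bands, add a Flatten layer before the Dense layers, and one-hot encode the labels to shape (#samples, 2). The random arrays are placeholders for the real xTrain/yTrain. Keeping the data as (N, 1, 6) with input_shape=(1, 6) would also remove the warning, but the convolution would then see only a single step.

```python
import numpy as np
from tensorflow import keras

nBands = 6        # spectral bands per pixel
n_samples = 1000  # placeholder; the real training set has 2519025 pixels

# Placeholder data: (pixels, 1, bands) as in the question, plus integer class labels 0/1.
xTrain = np.random.rand(n_samples, 1, nBands).astype("float32")
labels = np.random.randint(0, 2, size=n_samples)

# Treat the bands as a length-6 sequence with one channel: (pixels, bands, 1).
xTrain = xTrain.reshape(-1, nBands, 1)
# One-hot encode the labels so yTrain has shape (#samples, 2).
yTrain = np.eye(2, dtype="float32")[labels]

model = keras.Sequential([
    keras.Input(shape=(nBands, 1)),  # per-sample shape; the batch dimension is implicit
    keras.layers.Conv1D(filters=64, kernel_size=3, activation="relu", padding="same"),
    keras.layers.Flatten(),          # (nBands, 64) -> (nBands * 64,)
    keras.layers.Dense(128, activation="relu", kernel_initializer="glorot_normal"),
    keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
model.fit(xTrain, yTrain, epochs=2, batch_size=10, verbose=0)

probs = model.predict(xTrain[:5], verbose=0)
print(probs.shape)  # (5, 2)
```

With integer labels, loss="sparse_categorical_crossentropy" could be used instead of one-hot encoding; the one-hot version is shown because it matches the (#samples, 2) shape mentioned above.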

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. Any question please contact: yoyou2525@163.com.
