Keras - "ValueError: Error when checking target: expected activation_1 to have shape (None, 9) but got array with shape (9,1)"

I'm building a model to classify text into one of 9 classes, and I get this error when running it. activation_1 seems to refer to the convolutional layer's input, but I'm not sure what is wrong with the input.

num_classes=9
Y_train = keras.utils.to_categorical(Y_train, num_classes)
#Reshape data to add new dimension
X_train = X_train.reshape((100, 150, 1)) 
Y_train = Y_train.reshape((100, 9, 1)) 
model = Sequential() 
model.add(Conv1d(1, kernel_size=3, activation='relu', input_shape=(None, 1))) 
model.add(Dense(num_classes)) 
model.add(Activation('softmax')) 

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) 
model.fit(x=X_train,y=Y_train, epochs=200, batch_size=20)

Running this results in the following error:

"ValueError: Error when checking target: expected activation_1 to have shape (None, 9) but got array with shape (9,1)"

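There are several typos and bugs in your code. Note that the message says "Error when checking target", and activation_1 is the final softmax layer, so the complaint is about the shape of Y_train, not about the input.

  1. Keep the targets two-dimensional: Y_train = Y_train.reshape((100,9)). For a vector of 100 integer labels this is the shape to_categorical already returns, so the reshape to (100, 9, 1) can simply be dropped.

  2. Since you reshape X_train to (100,150,1), I assume each sample has 150 steps and 1 channel. So for Conv1D (your code has the typo Conv1d), use input_shape=(150,1).

  3. You need to flatten the output of Conv1D before feeding it into the Dense layer. Otherwise Dense is applied to every step separately and the output keeps the step axis.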
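To make the first point concrete, here is a minimal NumPy-only sketch of the label shapes. It assumes integer labels in the range 0..8 and uses indexing into np.eye as a stand-in for keras.utils.to_categorical, so that the snippet runs without Keras; the variable names are illustrative and not taken from the question.

```python
import numpy as np

num_classes = 9
num_samples = 100

# Hypothetical integer labels, one per sample, in the range 0..8.
rng = np.random.default_rng(0)
labels = rng.integers(0, num_classes, size=num_samples)

# One-hot encode by indexing an identity matrix (assumed equivalent to
# keras.utils.to_categorical for a 1D vector of integer labels).
one_hot = np.eye(num_classes)[labels]
print(one_hot.shape)          # (100, 9)

# The reshape from the question adds a trailing axis, so each target
# becomes a (9, 1) array instead of a length-9 vector.
bad_target = one_hot.reshape((num_samples, num_classes, 1))
print(bad_target.shape[1:])   # (9, 1)
```

The per-sample shape (9, 1) is what the error message reports, while the softmax output expects one length-9 vector per sample.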

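import numpy as np
import keras
from keras import Sequential
from keras.layers import Conv1D, Dense, Flatten

# Random stand-in data: 100 samples of 150 steps, integer labels in 0..8
X_train = np.random.normal(size=(100,150))
Y_train = np.random.randint(0,9,size=100)

num_classes=9
# One-hot encode the labels; the result has shape (100, 9)
Y_train = keras.utils.to_categorical(Y_train, num_classes)
# Reshape the inputs to add a channel dimension: (samples, steps, channels)
X_train = X_train.reshape((100, 150, 1))
# Targets stay 2D; this reshape is a no-op, kept to make the shape explicit
Y_train = Y_train.reshape((100, 9))
model = Sequential()
model.add(Conv1D(2, kernel_size=3, activation='relu', input_shape=(150,1)))
# Collapse (steps, filters) into one vector per sample before the classifier
model.add(Flatten())
model.add(Dense(num_classes, activation='softmax'))

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.fit(x=X_train,y=Y_train, epochs=200, batch_size=20)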
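As a rough check of why Flatten is needed, the per-sample shapes through the model can be worked out by hand. The sketch below assumes Conv1D's default settings ('valid' padding, stride 1); the helper conv1d_valid_length is a hypothetical name introduced here for illustration, not a Keras function.

```python
def conv1d_valid_length(steps, kernel_size, stride=1):
    """Output length of a 1D convolution with 'valid' padding."""
    return (steps - kernel_size) // stride + 1

steps, channels = 150, 1        # per-sample input shape (150, 1)
filters, kernel_size = 2, 3     # Conv1D(2, kernel_size=3)
num_classes = 9

conv_steps = conv1d_valid_length(steps, kernel_size)
conv_shape = (conv_steps, filters)      # Conv1D output per sample
flat_size = conv_steps * filters        # Flatten output per sample
with_flatten = (num_classes,)           # Dense output after Flatten

# Without Flatten, Dense acts on the last axis only, so the step axis survives.
without_flatten = (conv_steps, num_classes)

print(conv_shape, flat_size, with_flatten, without_flatten)
```

Only the flattened variant yields one length-9 prediction per sample, which is the shape that one-hot targets of shape (100, 9) can be compared against.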
