
Shapes error in Convolutional Neural Network

I am trying to train a neural network with the following structure:

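from keras.models import Sequential
from keras.layers import Conv1D, MaxPooling1D, Dropout, Dense

# The enclosing function was not shown in the question; build_model is an assumed name.
def build_model(num_labels):
    model = Sequential()

    model.add(Conv1D(filters=300, kernel_size=5, activation='relu', input_shape=(4000, 1)))
    model.add(Conv1D(filters=300, kernel_size=5, activation='relu'))
    model.add(MaxPooling1D(3))
    model.add(Conv1D(filters=320, kernel_size=5, activation='relu'))
    model.add(MaxPooling1D(3))
    model.add(Dropout(0.5))

    # Dense acts on the last axis only, so the output here is still 3D.
    model.add(Dense(num_labels, activation='softmax'))

    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    return model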

I get this error:

expected dense_1 to have shape (442, 3) but got array with shape (3, 1)

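My input is a set of phrases (12501 in total) that have been tokenized against the 4000 most relevant words, and there are 3 possible classes. So my input has train_x.shape = (12501, 4000), which I reshaped to (12501, 4000, 1) for the Conv1D layer. My labels have train_y.shape = (12501, 3), which I then reshaped to (12501, 3, 1).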

I am calling the fit function as follows:

model.fit(train_x, train_y, batch_size=32, epochs=10, verbose=1, validation_split=0.2, shuffle=True)

What am I doing wrong?

You do not need to reshape the labels for classification; leave train_y as (12501, 3). You can inspect your network structure with:

model.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv1d_1 (Conv1D)            (None, 3996, 300)         1800      
_________________________________________________________________
conv1d_2 (Conv1D)            (None, 3992, 300)         450300    
_________________________________________________________________
max_pooling1d_1 (MaxPooling1 (None, 1330, 300)         0         
_________________________________________________________________
conv1d_3 (Conv1D)            (None, 1326, 320)         480320    
_________________________________________________________________
max_pooling1d_2 (MaxPooling1 (None, 442, 320)          0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 442, 320)          0         
_________________________________________________________________
dense_1 (Dense)              (None, 442, 3)            963       
=================================================================
Total params: 933,383
Trainable params: 933,383
Non-trainable params: 0
_________________________________________________________________
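The shapes and parameter counts in this summary can be reproduced by hand. The following is a minimal sketch, assuming the Keras defaults used in the question ('valid' padding, stride 1 for Conv1D, and a pooling stride equal to the pool size); the helper names conv1d_len, pool1d_len and conv1d_params are illustrative and not part of Keras.

```python
def conv1d_len(n, kernel_size):
    # Output length of a 'valid' convolution with stride 1.
    return n - kernel_size + 1


def pool1d_len(n, pool_size):
    # Output length of 'valid' max pooling with stride equal to pool_size.
    return n // pool_size


def conv1d_params(in_channels, filters, kernel_size):
    # One kernel of size (kernel_size, in_channels) per filter, plus one bias each.
    return kernel_size * in_channels * filters + filters


n = 4000
n = conv1d_len(n, 5)   # 3996
n = conv1d_len(n, 5)   # 3992
n = pool1d_len(n, 3)   # 1330
n = conv1d_len(n, 5)   # 1326
n = pool1d_len(n, 3)   # 442
print(n)

dense_params = 320 * 3 + 3  # Dense(3) applied to the last axis of size 320
total = (conv1d_params(1, 300, 5)
         + conv1d_params(300, 300, 5)
         + conv1d_params(300, 320, 5)
         + dense_params)
print(total)  # 933383
```

The 963 parameters of dense_1 show that the Dense layer only sees the last axis (320 channels) and is applied independently at each of the 442 positions, which is why the time axis survives into the output.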

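The final output of your model has shape (None, 442, 3), but your labels have shape (None, 3, 1). Before the final Dense layer, add either a global pooling layer such as GlobalMaxPooling1D() or a Flatten() layer, which turns the 3D output into a 2D output suitable for classification or regression. The model then outputs (None, 3), which matches the original, unreshaped labels.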
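A sketch of what that change could look like, assuming the same layers as in the question. The function name build_model and the use_flatten switch are illustrative and do not come from the original post; placing Dropout after the pooling step is also just one possible choice.

```python
from keras.models import Sequential
from keras.layers import (Conv1D, MaxPooling1D, Dropout, Dense,
                          GlobalMaxPooling1D, Flatten)


def build_model(num_labels, use_flatten=False):
    # build_model and use_flatten are hypothetical names used for illustration.
    model = Sequential()
    model.add(Conv1D(filters=300, kernel_size=5, activation='relu',
                     input_shape=(4000, 1)))
    model.add(Conv1D(filters=300, kernel_size=5, activation='relu'))
    model.add(MaxPooling1D(3))
    model.add(Conv1D(filters=320, kernel_size=5, activation='relu'))
    model.add(MaxPooling1D(3))
    # Collapse (None, 442, 320) to 2D before the classifier:
    # Flatten gives (None, 442 * 320), GlobalMaxPooling1D gives (None, 320).
    model.add(Flatten() if use_flatten else GlobalMaxPooling1D())
    model.add(Dropout(0.5))
    model.add(Dense(num_labels, activation='softmax'))
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    return model


model = build_model(num_labels=3)
print(model.output_shape)  # (None, 3)
```

With this model, train_x keeps the shape (12501, 4000, 1) and train_y stays at (12501, 3), so the fit call from the question can be used unchanged. Flatten feeds many more inputs into the Dense layer than global pooling does, so it produces a much larger final layer.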

