簡體   English   中英

Keras 多分類器 ANN 中的 Argmax

[英]Argmax in a Keras multiclassifier ANN

我正在嘗試編寫一個 5 class 分類器 ANN,並且此代碼返回此錯誤:

    classifier = Sequential()
    
    classifier.add(Dense(units=10, input_dim=14, kernel_initializer='uniform', activation='relu'))
    
    classifier.add(Dense(units=6, kernel_initializer='uniform', activation='relu'))
    
    classifier.add(Dense(units=5, kernel_initializer='uniform', activation='softmax'))
    
    classifier.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    
    RD_Model = classifier.fit(X_train,y_train, batch_size=10 , epochs=10, verbose=1)


File "c:\Program Files\Python310\lib\site-packages\keras\backend.py", line 5119, in categorical_crossentropy
        target.shape.assert_is_compatible_with(output.shape)
    ValueError: Shapes (None, 1) and (None, 5) are incompatible

我認為這是因為我有一個概率矩陣而不是實際的 output,所以我一直在嘗試應用 argmax,但還沒有想出辦法

有人可以幫我嗎?

您是否嘗試過申請:

tf.keras.backend.argmax()

您可以使用以下命令定義lambda layer

from keras.layer import Lambda
from keras import backend as K

def argmax_layer(input):
  return K.argmax(input, axis=-1)

Keras提供了兩種用於定義 model 拓撲的范例。 您正在使用的代碼使用Sequential API 您可能必須恢復到Functional API

input_layer = Input(shape=(14,))
layer_1 = Dense(10, activation="relu")(input_layer)
layer_2 = Dense(6, activation="relu")(layer_1)
layer_3 = argmax_layer()(layer_2 )
output_layer= Dense(5, activation="linear")(layer_3 )

model = Model(inputs=input_layer, outputs=output_layer)

model.compile(optimizer='adam',
               loss='categorical_crossentropy', metrics=['accuracy'])

另一種選擇是實例化Keras Layer的繼承 class 。 https://www.tutorialspoint.com/keras/keras_customized_layer.htm

正如史努比博士所說,這確實是單熱編碼的問題......我錯過了這樣做,導致我的 model 無法正常工作。

所以我只對它進行了一次熱編碼:

encoder = LabelEncoder()
encoder.fit(y_train)
encoded_Y = encoder.transform(y_train)
# convert integers to dummy variables (i.e. one hot encoded)
dummy_y = np_utils.to_categorical(encoded_Y)

它在使用 dummy_y 后起作用。 謝謝您的幫助。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM