簡體   English   中英

即使輸出層有任何大小,keras 神經網絡如何找到要屬性的類?

[英]How does keras neural network find the class to attribute even if the output layer has any size?

我有一個帶有二進制類(真或假)的數據樣本。 神經網絡賦予每個類權重,最大值將決定屬性類。 但是為什么即使輸出層沒有適當數量的神經元,keras 也能工作? (= 班級數 = 2 在我的情況下,0 或 1)。

import keras
from model import *

X_train, X_test, y_train, y_test = train_test_split(df_features, df_labels, test_size=0.25, random_state=10)

model = keras.Sequential([
    keras.layers.Flatten(input_shape=(len(X_test.columns),)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(128, activation='softmax') # Shouldn't be two here ?
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# len(y_train.columns) == 1
history = model.fit(X_train, y_train, epochs=100, validation_split=0.25)

scores = model.evaluate(X_test, y_test, verbose=0)

print(model.metrics_names)
print('scores=', scores)

假設:它在最后添加了一個隱式層,或者它可能忽略了一些神經元,或者完全是別的什么?

編輯:添加數據

>>> print(y_train)
[0 0 0 ... 0 1 0]

>>> print(y_test)
      Class
1424      0
3150      1
2149      0
1700      0
4330      0
4200      0
# etc, ~1000 entries
>>> print('len(y_train)=', len(y_train))
len(y_train)= 2678
>>> print('len(y_test)=', len(y_test))
len(y_test)= 893

我相信問題在於您的損失sparse_cartegorical_crossentropy如何工作。 這種損失(與categorical_crossentropy相反)假設 y_actual 將作為標簽編碼格式而不是單熱編碼格式提供 意思是如果您要預測 5 個類,則 y_actual 數組提供為[0,2,4,1,2,2,3,3,1...]其中一維數組中的每個值代表一個5 個可能的班級中的班級編號。

讓我們直接從tf2 文檔中查看有關此損失的獨立使用的示例 -

y_true = [1, 2] #class number from 0 to 2
y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]] #3 class classification output
loss = tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred)
assert loss.shape == (2,)
loss.numpy()
[0.0513, 2.253]

在您的情況下,這意味着當您的模型返回 128 維輸出時,它假定此分類問題中有 128 個類。 但是,由於損失是sparse_categorical_crossentropy ,它等待接收 0-127 之間的單個數字,然后將其用於計算其誤差。

由於您總是在所有情況下都給它 0 或 1,因此它假定樣本所屬的實際類是 128 個類中的 0 類或 1 類,而沒有其他類。 因此,它的代碼運行但它是錯誤的,因為它不是從y_train (或y_test )作為二進制類讀取單個數字,而是假定它屬於 128 個其他類中的一個類。

print(y_train)
[0 0 0 ... 0 1 0]

#The first 0 here, is being considered as one class out of 128 other classes. 
#The code would still work if u changed that to say 105 instead of 0.
#Similarly for all the other 0s and 1s. 

希望這是有道理的。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM