[英]Tensorflow Macro F1 Score for multiclass and also for binary classification
我正在嘗試訓練 2 個 1D Conv 神經網絡 - 一個用於多類分類問題,第二個用於二元分類問題。 我的指標之一必須是這兩個問題的 Macro F1 分數。 但是我在使用來自 tensorflow 插件的tfa.metrics.F1Score
時遇到問題。
我有 3 個類編碼為 0、1、2。
網絡的最后一層和編譯方法如下所示( int_sequeces_input
是輸入層):
preds = layers.Dense(3, activation="softmax")(x)
model = keras.Model(int_sequences_input, preds)
f1_macro = F1Score(num_classes=3, average='macro')
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy',f1_macro])
但是,當我運行model.fit()
時,出現以下錯誤:
ValueError: Dimension 0 in both shapes must be equal, but are 3 and 1. Shapes are [3] and [1]. for '{{node AssignAddVariableOp_7}} = AssignAddVariableOp[dtype=DT_FLOAT](AssignAddVariableOp_7/resource, Sum_6)' with input shapes: [], [1].
X_train - (23658, 150)
y_train - (23658,)
我有 2 個類編碼為 0,1
網絡的最后一層和編譯方法如下所示( int_sequeces_input
是輸入層):
preds = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(int_sequences_input, preds)
print(model.summary())
f1_macro = F1Score(num_classes=2, average='macro')
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy',f1_macro])
同樣,當我運行model.fit()
時出現錯誤:
ValueError: Dimension 0 in both shapes must be equal, but are 2 and 1. Shapes are [2] and [1]. for '{{node AssignAddVariableOp_4}} = AssignAddVariableOp[dtype=DT_FLOAT](AssignAddVariableOp_4/resource, Sum_3)' with input shapes: [], [1].
X_train - (15770, 150)
y_train - (15770,)
所以我的問題是:如何使用宏觀 F1 分數評估我的兩個模型? 如何修復我的實現以使其與tfa.metrics.F1Score
一起使用? 或者有沒有其他方法可以在不使用tfa.metrics.F1Score
的情況下計算宏觀 F1 分數? 謝謝。
從其文檔頁面查看用法示例。
metric = tfa.metrics.F1Score(num_classes=3, threshold=0.5)
y_true = np.array([[1, 1, 1],
[1, 0, 0],
[1, 1, 0]], np.int32)
y_pred = np.array([[0.2, 0.6, 0.7],
[0.2, 0.6, 0.6],
[0.6, 0.8, 0.0]], np.float32)
metric.update_state(y_true, y_pred)
您可以看到它希望標簽采用 one-hot 格式。
但是考慮到您上面提到的形狀:
shapes of data:
X_train - (23658, 150)
y_train - (23658,)
看起來您的標簽采用索引格式。 嘗試使用tf.one_hot(y_train, num_classes)
將它們轉換為一個熱點。 您還需要將損失更改為loss='categorical_crossentropy'
。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.