
Measuring F1 score for multiclass classification natively in PyTorch

I am trying to implement the macro F1 score (F-measure) natively in PyTorch, instead of using the widely adopted sklearn.metrics.f1_score, so that the metric can be computed directly on the GPU.

From my understanding, in order to compute the macro F1 score I need to compute the F1 score from the sensitivity and precision of every label, and then average them all.
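
Concretely, with TP_c, FP_c and FN_c denoting the true positives, false positives and false negatives of class c out of C classes, the quantities involved are:

precision_c = TP_c / (TP_c + FP_c)
recall_c    = TP_c / (TP_c + FN_c)
F1_c        = 2 * precision_c * recall_c / (precision_c + recall_c)
macro-F1    = (1 / C) * sum over c of F1_c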

My attempt

My current implementation looks like this:

import torch

def confusion_matrix(y_pred: torch.Tensor, y_true: torch.Tensor, n_classes: int):
    # rows are true labels, columns are predicted labels
    conf_matrix = torch.zeros([n_classes, n_classes], dtype=torch.int)
    y_pred = torch.argmax(y_pred, 1)
    for t, p in zip(y_true.view(-1), y_pred.view(-1)):
        conf_matrix[t.long(), p.long()] += 1
    return conf_matrix

def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    conf_matrix = confusion_matrix(y_pred, y_true, self.classes)
    TP = conf_matrix.diag()
    f1_scores = torch.zeros(self.classes, dtype=torch.float)
    for c in range(self.classes):
        idx = torch.ones(self.classes, dtype=torch.long)
        idx[c] = 0
        FP = conf_matrix[c, idx].sum()
        FN = conf_matrix[idx, c].sum()
        sensitivity = TP[c] / (TP[c] + FN + self.epsilon)
        precision = TP[c] / (TP[c] + FP + self.epsilon)
        f1_scores[c] += 2.0 * ((precision * sensitivity) / (precision + sensitivity + self.epsilon))
    return f1_scores.mean()

Here self.classes is the number of labels and self.epsilon is a very small value, set to 1e-12, that prevents a ZeroDivisionError.

When training, I compute the metric for every batch and take the mean of all batch scores as the final score.
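
Schematically, that batch-averaging looks like this (model, loader and metric are placeholders for my model, DataLoader and the F1 module above, not names from my actual code):

import torch

batch_scores = []
with torch.no_grad():
    for inputs, y_true in loader:    # placeholder DataLoader
        y_pred = model(inputs)       # placeholder model returning logits
        batch_scores.append(metric(y_pred, y_true))
# final score: mean of the per-batch macro F1 values
final_f1 = torch.stack(batch_scores).mean()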

The problem

The problem is that when I compare my custom F1 score with sklearn's macro F1 score, they are rarely equal.

# example 1
eval_cce 0.5203, eval_f1 0.8068, eval_acc 81.5455, eval_f1_sci 0.8023,
test_cce 0.4784, test_f1 0.7975, test_acc 82.6732, test_f1_sci 0.8097
# example 2
eval_cce 0.3304, eval_f1 0.8211, eval_acc 87.4955, eval_f1_sci 0.8626,
test_cce 0.3734, test_f1 0.8183, test_acc 85.4996, test_f1_sci 0.8424
# example 3
eval_cce 0.4792, eval_f1 0.7982, eval_acc 81.8482, eval_f1_sci 0.8001,
test_cce 0.4722, test_f1 0.7905, test_acc 82.6533, test_f1_sci 0.8139

While I have tried scanning the internet, most of the cases I found cover binary classification. I have not yet been able to find an example that tries to do what I want to do.

My question

Is there something obviously wrong with my attempt?

Update (10.06.2020)

I have yet to figure out my mistake. Due to time constraints, I decided to just use the macro F1 score provided by sklearn. While it cannot work directly with GPU tensors, it is fast enough for my case anyway.
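
For completeness, the workaround simply moves the tensors to the CPU before calling sklearn, roughly like this (a sketch; logits and targets are placeholder names):

import torch
from sklearn.metrics import f1_score

def sklearn_macro_f1(logits: torch.Tensor, targets: torch.Tensor) -> float:
    # sklearn expects NumPy arrays, so detach the tensors and copy them to the CPU
    y_pred = torch.argmax(logits, dim=1).detach().cpu().numpy()
    y_true = targets.detach().cpu().numpy()
    return f1_score(y_true, y_pred, average='macro')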

However, it would be great if someone could figure this out, so that anyone else who stumbles upon this question can solve their problem.

Some time ago I wrote my own implementation in PyTorch:

from typing import Tuple

import torch


class F1Score:
    """
    Class for f1 calculation in Pytorch.
    """

    def __init__(self, average: str = 'weighted'):
        """
        Init.

        Args:
            average: averaging method
        """
        self.average = average
        if average not in [None, 'micro', 'macro', 'weighted']:
            raise ValueError('Wrong value of average parameter')

    @staticmethod
    def calc_f1_micro(predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate f1 micro.

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels

        Returns:
            f1 score
        """
        true_positive = torch.eq(labels, predictions).sum().float()
        f1_score = torch.div(true_positive, len(labels))
        return f1_score

    @staticmethod
    def calc_f1_count_for_label(predictions: torch.Tensor,
                                labels: torch.Tensor, label_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate f1 and true count for the label

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels
            label_id: id of current label

        Returns:
            f1 score and true count for label
        """
        # label count
        true_count = torch.eq(labels, label_id).sum()

        # true positives: labels equal to prediction and to label_id
        true_positive = torch.logical_and(torch.eq(labels, predictions),
                                          torch.eq(labels, label_id)).sum().float()
        # precision for label
        precision = torch.div(true_positive, torch.eq(predictions, label_id).sum().float())
        # replace nan values with 0
        precision = torch.where(torch.isnan(precision),
                                torch.zeros_like(precision).type_as(true_positive),
                                precision)

        # recall for label
        recall = torch.div(true_positive, true_count)
        # f1
        f1 = 2 * precision * recall / (precision + recall)
        # replace nan values with 0
        f1 = torch.where(torch.isnan(f1), torch.zeros_like(f1).type_as(true_positive), f1)
        return f1, true_count

    def __call__(self, predictions: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Calculate f1 score based on averaging method defined in init.

        Args:
            predictions: tensor with predictions
            labels: tensor with original labels

        Returns:
            f1 score
        """

        # simpler calculation for micro
        if self.average == 'micro':
            return self.calc_f1_micro(predictions, labels)

        f1_score = 0
        # note: this assumes the labels are the consecutive integers 1..N
        for label_id in range(1, len(labels.unique()) + 1):
            f1, true_count = self.calc_f1_count_for_label(predictions, labels, label_id)

            if self.average == 'weighted':
                f1_score += f1 * true_count
            elif self.average == 'macro':
                f1_score += f1

        if self.average == 'weighted':
            f1_score = torch.div(f1_score, len(labels))
        elif self.average == 'macro':
            f1_score = torch.div(f1_score, len(labels.unique()))

        return f1_score

You can test it in the following way:

from sklearn.metrics import f1_score
import numpy as np
errors = 0
for _ in range(10):
    labels = torch.randint(1, 10, (4096, 100)).flatten()
    predictions = torch.randint(1, 10, (4096, 100)).flatten()
    labels1 = labels.numpy()
    predictions1 = predictions.numpy()

    for av in ['micro', 'macro', 'weighted']:
        f1_metric = F1Score(av)
        my_pred = f1_metric(predictions, labels)
        
        f1_pred = f1_score(labels1, predictions1, average=av)
        
        if not np.isclose(my_pred.item(), f1_pred.item()):
            print('!' * 50)
            print(f1_pred, my_pred, av)
            errors += 1

if errors == 0:
    print('No errors!')
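
Since the class only uses torch operations, it should also run directly on GPU tensors, e.g. (assuming a CUDA device is available):

f1_metric = F1Score('macro')
labels = torch.randint(1, 10, (4096,), device='cuda')
predictions = torch.randint(1, 10, (4096,), device='cuda')
score = f1_metric(predictions, labels)  # result stays on the GPU as a tensor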
