簡體   English   中英

如何在 PyTorch 中使用 Real-World-Weight 交叉熵損失

[英]How to use Real-World-Weight Cross-Entropy loss in PyTorch

我正在研究多類分類,其中一些錯誤比其他錯誤更嚴重。 因此,我想將成本合並到我的損失 function 中。我在 Real-World-Weight Cross-Entropy 的名稱下找到了它,如本文所述。 公式如下:

在此處輸入圖像描述

除了標准CrossEntropyLossweight參數外,我還沒有找到任何現成的實現,我認為它與我的用例完全不同(據我所知,錯誤分類一個類別的成本是相同的,無論它與哪個類別混淆了)。

我如何在 PyTorch 中應用它?

import torch.nn as nn
import torch

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=True)
target = torch.empty(3, dtype=torch.long).random_(5)

cost_matrix = torch.zeros((5, 5))
cost_matrix[1, 0] = 0.4
cost_matrix[2, 0] = 0.1
cost_matrix[2, 1] = 0.9
cost_matrix[3, 0] = 0.4
cost_matrix[3, 1] = 0.9
cost_matrix[3, 2] = 0.1
cost_matrix[4, 0] = 0.1
cost_matrix[4, 1] = 0.4
cost_matrix[4, 2] = 0.9
cost_matrix[4, 3] = 0.1
cost_matrix[0, 1] = 0.4
cost_matrix[0, 2] = 0.1
cost_matrix[1, 2] = 0.9
cost_matrix[0, 3] = 0.4
cost_matrix[1, 3] = 0.9
cost_matrix[2, 3] = 0.1
cost_matrix[0, 4] = 0.1
cost_matrix[1, 4] = 0.4
cost_matrix[2, 4] = 0.9
cost_matrix[3, 4] = 0.1
 

我建議您調查此筆記本,它包含您問題中指出的論文作者的完整實現: The-Real-World-Weight-Crossentropy-Loss-Function

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM