簡體   English   中英

為什么 Tensorflow 和 Pytorch CrossEntropy 損失對同一示例返回不同的值

[英]Why is the Tensorflow and Pytorch CrossEntropy loss returns different values for same example

我嘗試獲取 Tensorflow 和 Pytorch CrossEntropyLoss 但它返回不同的值,我不知道為什么。 請找到以下代碼和結果。 感謝您的投入和幫助。

import tensorflow as tf
import numpy as np

y_true = [3, 3, 1]
y_pred = [
    [0.3377, 0.4867, 0.8842, 0.0854, 0.2147],
    [0.4853, 0.0468, 0.6769, 0.5482, 0.1570],
    [0.0976, 0.9899, 0.6903, 0.0828, 0.0647]
]

scce3 = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.AUTO)
loss3 = scce3(y_true, y_pred).numpy()
print(loss3)

上面的結果是:1.69

Pytorch 損失:

from torch import nn
import torch
loss = nn.CrossEntropyLoss()
y_true = torch.Tensor([3, 3, 1]).long()
y_pred = torch.Tensor([
    [0.3377, 0.4867, 0.8842, 0.0854, 0.2147],
    [0.4853, 0.0468, 0.6769, 0.5482, 0.1570],
    [0.0976, 0.9899, 0.6903, 0.0828, 0.0647]
])
loss2 = loss(y_pred, y_true)
print(loss2)

以上損失值為:1.5

Tensorflow 的 CrossEntropy 期望概率作為輸入(即tf.nn.softmax操作后的值),而 PyTorch 的 CrossEntropyLoss 期望原始輸入,或更常見的名稱,logits。 如果使用 softmax 操作,值應該相同:

import tensorflow as tf
import numpy as np

y_true = [3, 3, 1]
y_pred = [
    [0.3377, 0.4867, 0.8842, 0.0854, 0.2147],
    [0.4853, 0.0468, 0.6769, 0.5482, 0.1570],
    [0.0976, 0.9899, 0.6903, 0.0828, 0.0647]
]

scce3 = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.AUTO)
loss3 = scce3(y_true, tf.nn.softmax(y_pred)).numpy()
print(loss3)

>>> 1.5067214
from torch import nn
import torch
loss = nn.CrossEntropyLoss()
y_true = torch.Tensor([3, 3, 1]).long()
y_pred = torch.Tensor([
    [0.3377, 0.4867, 0.8842, 0.0854, 0.2147],
    [0.4853, 0.0468, 0.6769, 0.5482, 0.1570],
    [0.0976, 0.9899, 0.6903, 0.0828, 0.0647]
])
loss2 = loss(y_pred, y_true)
print(loss2)

>>> tensor(1.5067)

由於數值穩定性的LogSumExp技巧,通常建議使用原始輸入 (logits)。 如果您使用的是 Tensorflow,我建議您改用tf.nn.softmax_cross_entropy_with_logits function 或其對應的稀疏函數。 編輯: SparseCategoricalCrossentropy class 也有一個關鍵字參數from_logits=False可以設置為True達到同樣的效果。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM