繁体   English   中英

使用 pytorch 的不平衡数据的焦点损失

[英]focal loss for imbalanced data using pytorch

我想使用 pytorch 对多类不平衡数据使用焦点损失。 我搜索得到并尝试使用此代码,但出现错误


class_weights=tf.constant([0.21, 0.45, 0.4, 0.46, 0.48, 0.49])

loss_fn=nn.CrossEntropyLoss(weight=class_weights,reduction='mean')

并在火车 function 中使用它



    preds = model(sent_id, mask, labels)
   
     # compu25te the validation loss between actual and predicted values
    alpha=0.25
    gamma=2
    ce_loss = loss_fn(preds, labels)
    pt = torch.exp(-ce_loss)
    focal_loss = (alpha * (1-pt)**gamma * ce_loss).mean()

错误是

TypeError: cannot assign 'tensorflow.python.framework.ops.EagerTensor' object to buffer 'weight' (torch Tensor or None required)

在这一行

loss_fn=nn.CrossEntropyLoss(weight=class_weights,reduction='mean')

您正在混合 tensorflow 和 pytorch 对象。

尝试:

class_weights=torch.tensor([0.21, ...], requires_grad=False)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM