简体   繁体   English

pytorch 中的加权 mse 损失

[英]weighted mse loss in pytorch

def weighted_mse_loss(input_tensor, target_tensor, weight = 1):
    observation_dim = input_tensor.size()[-1]
    streched_tensor = ((input_tensor - target_tensor) ** 2).view(-1, observation_dim)
    entry_num = float(streched_tensor.size())[0]
    non_zero_entry_num = torch.sum(streched_tensor[:,0] != 0).float()
    weighted_tensor = torch.mm(
        ((input_tensor - target_tensor)**2).view(-1, observation_dim),
        (torch.diag(weight.float().view(-1)))
    )
    return torch.mean(weighted_tensor) * weight.nelement() * entry_num / non_zero_entry_num

I can't understand how the code gives weighted Mean Square Error loss.我无法理解代码如何给出加权均方误差损失。 I get that observation_dim is the final output dimension, (the class number I guess), and after that line, I don't get it.我知道observation_dim是最终的输出维度,(我猜是班级编号),在那行之后,我不明白。 Could someone help me figure out how the code calculates the loss?有人可以帮我弄清楚代码是如何计算损失的吗?

Thanks a lot.非常感谢。

    def weighted_mse_loss(input, target, weight):
        return (weight * (input - target) ** 2).mean()

try this, hope this can help.试试这个,希望这可以帮助。 All arguments need tensored.所有参数都需要张量。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM