
Pytorch - compute accuracy UNet multi-class segmentation

I'm trying to run a UNet model in PyTorch for multi-class image segmentation. I found an architecture for the model online that apparently works... I have 100 classes, my input is a tensor of size [8, 3, 32, 32], my label is [8, 32, 32], and as expected my output is [8, 100, 32, 32].

I want to compute the accuracy at every iteration, so I used the following code:

def multi_acc(pred, label):
    probs = torch.log_softmax(pred, dim = 1)
    _, tags = torch.max(probs, dim = 1)
    corrects = (tags == label).float()
    acc = corrects.sum()/len(corrects)
    acc = torch.round(acc)*100
    return acc

But then when I run the training, I get an accuracy that is always the same:

 Epoch : [2] [1/38311] Loss : 0.3168763518333435
 Acc: 102400.0
 Epoch : [2] [2/38311] Loss : 0.31527179479599
 Acc: 102400.0
 Epoch : [2] [3/38311] Loss : 0.2920961081981659
 Acc: 102400.0

And it keeps going like this... If anyone has an idea of what is happening here, that would be super great!

Thanks for the answers.

corrects is a 3-dimensional tensor (batch, width, height) or something like that. When you call acc = corrects.sum() / len(corrects), len returns the size of the first dimension of the tensor, in this case 8 I think. Instead, use .numel() to get the total number of elements in the 3-dimensional tensor. I also recommend using torch.eq(). And don't round at the end: acc should be between 0 and 1 before rounding, so if you round it you'll always get either 0 or 1, which corresponds to 0% or 100% accuracy after converting to a percentage. Leave your accuracy metric unrounded and round it only when you print it.

def multi_acc(pred, label):
    probs = torch.log_softmax(pred, dim=1)
    _, tags = torch.max(probs, dim=1)
    corrects = torch.eq(tags, label).int()
    acc = corrects.sum() / corrects.numel()
    return acc

You calculate the accuracy with:

acc = corrects.sum()/len(corrects)

corrects has a size of torch.Size([8, 32, 32]). Taking the sum with corrects.sum() gives you the number of correctly classified pixels, of which there are a total of 8 * 32 * 32 = 8192. The accuracy should be num_correct / num_total, but you're dividing by len(corrects) == 8. To get the total number of elements you can use torch.numel.
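A quick sketch of the difference, using the shapes from the question (the all-ones tensor stands in for a comparison result where every pixel was classified correctly):

```python
import torch

# Simulate the comparison result: every pixel classified correctly
corrects = torch.ones(8, 32, 32)

print(len(corrects))      # 8    -- only the first (batch) dimension
print(corrects.numel())   # 8192 -- 8 * 32 * 32, the total element count

print(corrects.sum() / len(corrects))     # tensor(1024.) -- not an accuracy
print(corrects.sum() / corrects.numel())  # tensor(1.)    -- 100% accuracy
```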

Another problem is that you're rounding your accuracy:

acc = torch.round(acc)*100

The accuracy is a value between 0 and 1. By rounding it, you get 0 for everything below 0.5 and 1 for everything else, which means you can only ever tell whether you've exceeded 50% accuracy. You need to remove the rounding entirely.
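A minimal illustration of what the rounding destroys (the 0.85 value is made up for the example):

```python
import torch

acc = torch.tensor(0.85)       # an 85% accuracy, expressed as a fraction

print(torch.round(acc) * 100)  # tensor(100.) -- the real value is lost
print(acc * 100)               # tensor(85.)  -- round only when printing
```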

Applying these changes, you get the following function. I also removed the log_softmax, which leaves the ordering unchanged (larger values have larger probabilities). Since you're not using the probabilities, it has no effect:

def multi_acc(pred, label):
    _, tags = torch.max(pred, dim=1)
    corrects = (tags == label).float()
    acc = corrects.sum() / corrects.numel()
    acc = acc * 100
    return acc
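As a quick sanity check with the shapes from the question (the random logits and labels below are made up, so the accuracy should hover around chance level, i.e. roughly 1% for 100 classes):

```python
import torch

torch.manual_seed(0)
pred = torch.randn(8, 100, 32, 32)          # random logits
label = torch.randint(0, 100, (8, 32, 32))  # random ground-truth classes

tags = pred.argmax(dim=1)                   # predicted class per pixel
acc = (tags == label).float().mean() * 100  # .mean() == .sum() / .numel()
print(f"Acc: {acc:.2f}%")                   # around 1% for random data
```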
