
Pytorch - compute accuracy UNet multi-class segmentation

I'm trying to run a UNet model in PyTorch for multi-class image segmentation. I found an architecture online that apparently works. I have 100 classes; my input is a tensor of size [8, 3, 32, 32], my label is [8, 32, 32], and, as expected, my output is [8, 100, 32, 32].

I want to compute the accuracy at every iteration, so I followed this code for the accuracy computation:

def multi_acc(pred, label):
    probs = torch.log_softmax(pred, dim = 1)
    _, tags = torch.max(probs, dim = 1)
    corrects = (tags == label).float()
    acc = corrects.sum()/len(corrects)
    acc = torch.round(acc)*100
    return acc

But when I run the training, I get an accuracy that is always the same:

 Epoch : [2] [1/38311] Loss : 0.3168763518333435
 Acc: 102400.0
 Epoch : [2] [2/38311] Loss : 0.31527179479599
 Acc: 102400.0
 Epoch : [2] [3/38311] Loss : 0.2920961081981659
 Acc: 102400.0

And it keeps going like this... If anyone has an idea that could help me understand this, that would be super great!

Thanks for the answers.

corrects is a 3-dimensional tensor (batch, width, height) or something like that. When you call acc = corrects.sum() / len(corrects), len returns the size of the first dimension of the tensor, which is 8 here. Instead, use .numel() to get the total number of elements in the 3-dimensional tensor. I also recommend using torch.eq(). And don't round at the end: acc should be between 0 and 1 before rounding, so if you round it you'll always get either 0 or 1, which corresponds to 0% or 100% accuracy after converting to a percentage. Leave your accuracy metric unrounded and round it only when you print it.

def multi_acc(pred, label):
    probs = torch.log_softmax(pred, dim=1)
    _, tags = torch.max(probs, dim=1)
    corrects = torch.eq(tags, label).int()
    acc = corrects.sum() / corrects.numel()
    return acc
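To see the difference concretely, here is a small check using the shapes from the question (the all-ones tensor is just a stand-in for the corrects mask):

```python
import torch

# Stand-in for the corrects mask: batch of 8, 32x32 predictions, all correct.
corrects = torch.ones(8, 32, 32)

print(len(corrects))     # size of the first (batch) dimension only: 8
print(corrects.numel())  # total number of elements: 8 * 32 * 32 = 8192
```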

You calculate the accuracy with:

acc = corrects.sum()/len(corrects)

corrects has a size of torch.Size([8, 32, 32]); taking the sum with corrects.sum() gives you the number of correctly classified pixels, out of a total of 8 * 32 * 32 = 8192. The accuracy should be num_correct / num_total, but you're dividing by len(corrects) == 8. To get the total number of elements you can use torch.numel.
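A quick sketch (with made-up tensors where every pixel happens to be classified correctly) shows how these two bugs combine to produce exactly the constant Acc: 102400.0 from the training log:

```python
import torch

# Hypothetical perfect predictions with the question's shapes.
label = torch.zeros(8, 32, 32, dtype=torch.long)
tags = torch.zeros(8, 32, 32, dtype=torch.long)
corrects = (tags == label).float()

acc = corrects.sum() / len(corrects)  # 8192 / 8 = 1024.0, not a fraction
acc = torch.round(acc) * 100          # 1024.0 * 100 = 102400.0
print(acc.item())                     # 102400.0, as seen in the log
```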

Another problem is that you're rounding your accuracy:

acc = torch.round(acc)*100

The accuracy is a value between 0 and 1. By rounding it, you'll get 0 for everything below 0.5 and 1 for everything else. That means you would only determine whether you've achieved over 50% accuracy. You need to remove the rounding entirely.
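A minimal illustration, assuming a few hypothetical accuracy values a run might produce:

```python
import torch

# torch.round collapses any accuracy in [0, 1] to exactly 0 or 1,
# so after multiplying by 100 you only ever see 0.0 or 100.0.
reported = [(torch.round(torch.tensor(a)) * 100).item()
            for a in (0.49, 0.51, 0.99)]
print(reported)  # [0.0, 100.0, 100.0]
```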

Applying these changes, you get the following function. I also removed the log_softmax , which leaves the order unchanged (larger values have larger probabilities). Since you're not using the probabilities, it has no effect:

def multi_acc(pred, label):
    _, tags = torch.max(pred, dim=1)
    corrects = (tags == label).float()
    acc = corrects.sum() / corrects.numel()
    acc = acc * 100
    return acc
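For completeness, here is a self-contained usage sketch with random tensors of the shapes from the question (the corrected function is repeated so the snippet runs on its own):

```python
import torch

def multi_acc(pred, label):
    # Corrected accuracy: fraction of matching pixels, as a percentage.
    _, tags = torch.max(pred, dim=1)
    corrects = (tags == label).float()
    acc = corrects.sum() / corrects.numel()
    return acc * 100

# Hypothetical tensors with the question's shapes.
pred = torch.randn(8, 100, 32, 32)           # logits for 100 classes
label = torch.randint(0, 100, (8, 32, 32))   # ground-truth class per pixel

acc = multi_acc(pred, label)
print(acc.item())  # a percentage between 0 and 100
```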
