PyTorch Lightning metrics: ValueError: preds and target must have same number of dimensions, or one additional dimension for preds

Googling this gets you nowhere, so I decided to help future me and others by posting this as a searchable question.


def __init__(self):
    ...
    # Accuracy metric, accumulated across validation batches
    self.val_acc = pl.metrics.Accuracy()

def validation_step(self, batch, batch_index):
    ...
    # This call raises the ValueError
    self.val_acc.update(log_probs, label_batch)

gives

ValueError: preds and target must have same number of dimensions, or one additional dimension for preds

when log_probs.shape == (16, 4) and label_batch.shape == (16, 4)

What's the issue?

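pl.metrics.Accuracy() expects the target to be a batch of integer class labels (dtype=torch.long), one per sample, not one-hot encoded labels.

So the one-hot labels should be converted to class indices before being passed in: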

self.val_acc.update(log_probs, torch.argmax(label_batch.squeeze(), dim=1))
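To make the shape change concrete, here is a minimal sketch in plain PyTorch. The tensors are hypothetical stand-ins for the ones in the question (random values, batch of 16, 4 classes), and the accuracy is computed by hand rather than through pl.metrics, so it only illustrates the conversion, not the metric's internals.

```python
import torch

# Hypothetical stand-ins for the tensors in the question.
batch_size, num_classes = 16, 4
log_probs = torch.log_softmax(torch.randn(batch_size, num_classes), dim=1)

# One-hot labels of shape (16, 4), built from known class indices.
true_classes = torch.randint(0, num_classes, (batch_size,))
label_batch = torch.nn.functional.one_hot(true_classes, num_classes)

# argmax over the class dimension recovers integer class indices, shape (16,).
target = torch.argmax(label_batch, dim=1)

# Hand-rolled accuracy, used here only to show that the shapes now line up.
accuracy = (torch.argmax(log_probs, dim=1) == target).float().mean()
```

The resulting target has shape (16,) and dtype torch.long, which is the form that can be passed as the second argument in place of label_batch.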


This is the same target format that torch.nn.CrossEntropyLoss conventionally takes: one integer class index per sample rather than a one-hot vector.
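For comparison, a small sketch of the torch.nn.CrossEntropyLoss case. The logits and labels are made up for illustration; note that CrossEntropyLoss takes raw scores (it applies log-softmax internally), so hypothetical logits are used here instead of the log_probs from the question.

```python
import torch

torch.manual_seed(0)

# Hypothetical raw scores, shape (batch, classes).
logits = torch.randn(16, 4)

# Hypothetical one-hot labels, shape (16, 4).
one_hot = torch.nn.functional.one_hot(torch.randint(0, 4, (16,)), 4)

# Convert to class indices, shape (16,), dtype torch.long.
class_indices = torch.argmax(one_hot, dim=1)

loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(logits, class_indices)  # scalar tensor (mean over the batch)
```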
