
PyTorch error: RuntimeError: expected scalar type Long but found Double

I ran into the following error while training a BERT classifier. The relevant types are:

type(b_input_mask) = type(b_labels) = torch.Tensor      

type(b_labels[i]) = tensor(1., dtype=torch.float64)

type(b_input_mask[i]) = <class 'torch.Tensor'>

What could be causing this data type error, given that I have not cast any variable to either Long or Double?

Thanks in advance!

[Image: error stack trace]

In a classification task, the target labels must be class indices of type Long (torch.int64), but yours are float64, which PyTorch calls Double. That is exactly the mismatch the error message reports:

type(b_labels[i]) = tensor(1., dtype=torch.float64)

=>

type(b_labels[i]) = tensor(1., dtype=torch.long)
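A minimal sketch reproducing the mismatch, assuming the loss is a standard cross-entropy over class indices (as in a typical BERT sequence classifier). The logits are random stand-ins for real model output, and the exact error text differs between PyTorch versions and devices, so the sketch only checks that the float64 labels are rejected.

```python
import torch
import torch.nn.functional as F

# Random stand-in logits: batch of 4 examples, 2 classes (float32, like model output)
logits = torch.randn(4, 2)

# Labels as they arrive in the question: float64, which PyTorch also calls "Double"
labels_double = torch.tensor([1., 0., 1., 1.], dtype=torch.float64)

# The type names in the error message are aliases for these dtypes
assert torch.float64 == torch.double
assert torch.int64 == torch.long

try:
    F.cross_entropy(logits, labels_double)
    double_failed = False
except RuntimeError as err:
    # Message wording depends on the PyTorch version and device
    double_failed = True
    print("float64 labels rejected:", err)

# Class-index targets must be int64 ("Long")
labels_long = labels_double.long()
loss = F.cross_entropy(logits, labels_long)
print("loss with Long labels:", loss.item())
```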

Use torch.Tensor.long() to convert the tensor to the expected Long type; it is equivalent to .to(torch.int64).
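One caveat worth knowing: .long() truncates toward zero rather than rounding, so a label stored as 0.9999 would silently become class 0. A small sketch, assuming your float labels might carry such floating-point noise; to_class_indices is a hypothetical helper name, not a PyTorch function.

```python
import torch

def to_class_indices(labels: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: convert float class labels to int64 (Long) indices.

    Rounding first guards against values such as 0.9999 that a plain
    .long() call would truncate down to 0.
    """
    return labels.round().long()

noisy = torch.tensor([0.9999, 1.0, 2.0], dtype=torch.float64)
print(noisy.long())             # truncates: tensor([0, 1, 2])
print(to_class_indices(noisy))  # rounds first: tensor([1, 1, 2])
```

If your labels are guaranteed to be exact whole numbers, plain .long() is enough.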

# For example, convert the labels where you pass them to the model:
..., labels=b_labels.long())
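Alternatively, the labels can be created as Long once, when the dataset is built, so every batch's b_labels already has the right type and no per-call conversion is needed. This is a sketch assuming a TensorDataset/DataLoader pipeline like those common in BERT fine-tuning scripts; the token IDs, masks, and label values are made-up placeholders.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Placeholder inputs: 4 sequences of length 8 (made-up values, not real tokens)
input_ids = torch.randint(0, 1000, (4, 8))
attention_masks = torch.ones(4, 8, dtype=torch.long)

# Labels that ended up as floats somewhere upstream
raw_labels = [1.0, 0.0, 1.0, 0.0]

# Build the label tensor as int64 (Long) up front
labels = torch.tensor(raw_labels, dtype=torch.long)

dataset = TensorDataset(input_ids, attention_masks, labels)
loader = DataLoader(dataset, batch_size=2)

for b_input_ids, b_input_mask, b_labels in loader:
    # b_labels is already Long, so it can be passed as labels=b_labels directly
    print(b_labels.dtype)  # torch.int64
```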

