I have run into the following error while training a BERT classifier. The types involved are:
type(b_input_mask) = type(b_labels) = torch.Tensor
b_labels[i] = tensor(1., dtype=torch.float64)
type(b_input_masks[i]) = <class 'torch.Tensor'>
What could be causing a data type error here, given that I have not cast any variable to either long or double?
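In a classification task, the labels are expected to have dtype torch.long (int64), because the loss typically treats them as integer class indices. Your labels are float64 instead, even though you never cast them explicitly: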
b_labels[i] = tensor(1., dtype=torch.float64)
=>
b_labels[i] = tensor(1)  (dtype torch.int64, i.e. long)
You can use torch.Tensor.long() to convert the tensor to the expected long type.
# For example, pass the converted labels in your model call like this
..., labels = b_labels.long())
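As a minimal sketch of the conversion, the snippet below builds a float64 label tensor like the one in the question, converts it with .long(), and passes the result to a cross-entropy loss. The label values, the logits, and the variable names other than b_labels are made up for illustration and stand in for the real batch and model output.

```python
import torch
import torch.nn.functional as F

# Hypothetical labels with the same dtype as in the question (float64).
b_labels = torch.tensor([1.0, 0.0, 1.0], dtype=torch.float64)

# Hypothetical logits for 3 examples and 2 classes, standing in for model output.
logits = torch.tensor([[0.2, 1.5], [1.1, -0.3], [0.0, 0.7]])

# .long() returns a new int64 tensor; b_labels itself is left unchanged.
labels_long = b_labels.long()

# With integer class indices as targets, the loss can be computed.
loss = F.cross_entropy(logits, labels_long)

print(b_labels.dtype, labels_long.dtype)
```

Because .long() does not modify the tensor in place, either reassign the result (b_labels = b_labels.long()) or convert inline at the call site as in the answer above.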