簡體   English   中英

Pytorch 錯誤,RuntimeError:預期標量類型 Long 但發現 Double

[英]Pytorch Error, RuntimeError: expected scalar type Long but found Double

我在訓練 BERT 分類器時遇到了以下錯誤。

type(b_input_mask) = type(b_labels) = torch.Tensor      

type(b_labels[i]) = tensor(1., dtype=torch.float64)

type(b_input_masks[i]) = class'torch.Tensor'

由於我沒有將任何變量類型轉換為 long 或 double,這里可能出現的數據類型錯誤是什么?

提前致謝! 錯誤堆棧跟蹤

在分類任務中,輸入標簽的數據類型應該是 Long,但您將它們分配為 float64

type(b_labels[i]) = tensor(1., dtype=torch.float64)

=>

type(b_labels[i]) = tensor(1., dtype=torch.long)

您可以使用torch.Tensor.long將張量轉換為預期long類型。

# Here, you can pass parameter like this in your call
..., labels = b_labels.long())

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM