簡體   English   中英

RuntimeError: 預期 object 的標量類型 Long 但在調用 _thnn_nll_loss_forward 時獲得了參數 #2 'target' 的標量類型 Float

[英]RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target' in call to _thnn_nll_loss_forward

我正在 Tweeter 數據集上嘗試 Bert。 我遇到以下錯誤消息。

# set initial loss to infinite
best_valid_loss = float('inf')

# empty lists to store training and validation loss of each epoch
train_losses=[]
valid_losses=[]

#for each epoch
for epoch in range(epochs):
 
    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))

#train model
train_loss, _ = train()

#evaluate model
valid_loss, _ = evaluate()

#save the best model
if valid_loss < best_valid_loss:
    best_valid_loss = valid_loss
    torch.save(model.state_dict(), 'saved_weights.pt')

# append training and validation loss
train_losses.append(train_loss)
valid_losses.append(valid_loss)

print(f'\nTraining Loss: {train_loss:.3f}')
print(f'Validation Loss: {valid_loss:.3f}')

這是一個很長的代碼。 搜索問題導致我將.float() 更改為 long()。 我已經這樣做了。 請建議我解決方案。 非常重要:相同的代碼在另一個數據集(具有相同數量的列和相同類型的數據)上運行良好,但不適用於推文數據。 (唯一的區別是大小。以前有 5500 個條目,而推文數據集有 10000 個條目)

我已經搜索了很多上述錯誤。 最后,我發現出現上述錯誤的主要原因是“沒有正確清理數據集”。 原因(在我的情況下)是 label 列中的值顯示為浮點數,而不是整數。 通過使用 pandas 我將所有浮點值更改為 int 之后,代碼成功運行。 因此,花更多的時間來清理數據而不是編寫代碼。 謝謝你。

您是否有要預測的分類目標(即分類)? 假設一個名為 y 的二進制目標有 0 和 1? 您的目標變量可能不是編碼為int64 (即 Long 格式),而是編碼為int32 只需在創建DataLoader之前將該目標變量轉換為 64 位,就可以了。

有兩種可能的方式:

  • 要么保留為 numpy 陣列並在那里進行轉換,
  • 或轉換為torch數組,然后轉換為Long格式。
import torch
import numpy as np
y

# array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
#       1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
#       1])

y.dtype

# dtype('int32')
# numpy
y = y.astype('int64')

y
# array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1,
#       1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
#       1], dtype=int64)

或者

# torch
y = torch.from_numpy(y).to(dtype=torch.long)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM