簡體   English   中英

運行時錯誤:預期標量類型 Float 但發現 Double(LSTM 分類器)

[英]RuntimeError: expected scalar type Float but found Double (LSTM classifier)

我正在訓練我的 LSTM 分類器。

epoch_num = 30

train_log = []
test_log = []
set_seed(111)
for epoch in range(1, epoch_num+1):

running_loss = 0    
train_loss = []
lstm_classifier.train()
for (inputs, labels) in tqdm(train_loader, desc='Training epoch ' + str(epoch), leave=False):        
    inputs, labels = inputs.to(device), labels.to(device)        
    optimizer.zero_grad()
    outputs = lstm_classifier(inputs)   
    loss = criterion(outputs, labels)
    loss.backward()                
    optimizer.step()        
    train_loss.append(loss.item())
train_log.append(np.mean(train_loss))

running_loss = 0
test_loss = []
lstm_classifier.eval()
with torch.no_grad():                
    for (inputs, labels) in tqdm(test_loader, desc='Test', leave=False):         
        inputs, labels = inputs.to(device), labels.to(device)        
        outputs = lstm_classifier(inputs)                       
        loss = criterion(outputs, labels)            
        test_loss.append(loss.item())
test_log.append(np.mean(test_loss))    
plt.plot(range(1, epoch+1), train_log, color='C0')
plt.plot(range(1, epoch+1), test_log, color='C1')
display.clear_output(wait=True)
display.display(plt.gcf())

錯誤是:

RuntimeError Traceback (最近一次調用最后一次) in ()

     23         print((labels.dtype))
     24         print(outputs[:,0].dtype)
---> 25         loss = criterion(outputs, labels)
     26         loss.backward()
     27         optimizer.step()

2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2822     if size_average is not None or reduce is not None:
   2823         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2825 
   2826 

運行時錯誤:預期標量類型為 Float,但發現為 Double

如何解決?

運行時錯誤:預期標量類型為 Float,但發現為 Double

loss = criterion(outputs, labels)的錯誤非常清楚,因為它要求您的數據類型為浮點數而不是雙精度數,但它沒有明確說明是outputs還是label創建了這個。

我的猜測是因為標簽。 嘗試通過執行labels.float()將其轉換為浮動

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM