简体   繁体   English

运行时错误:预期标量类型 Float 但发现 Double(LSTM 分类器)

[英]RuntimeError: expected scalar type Float but found Double (LSTM classifier)

I'm training my LSTM classifier.我正在训练我的 LSTM 分类器。

epoch_num = 30

train_log = []
test_log = []
set_seed(111)
for epoch in range(1, epoch_num+1):

running_loss = 0    
train_loss = []
lstm_classifier.train()
for (inputs, labels) in tqdm(train_loader, desc='Training epoch ' + str(epoch), leave=False):        
    inputs, labels = inputs.to(device), labels.to(device)        
    optimizer.zero_grad()
    outputs = lstm_classifier(inputs)   
    loss = criterion(outputs, labels)
    loss.backward()                
    optimizer.step()        
    train_loss.append(loss.item())
train_log.append(np.mean(train_loss))

running_loss = 0
test_loss = []
lstm_classifier.eval()
with torch.no_grad():                
    for (inputs, labels) in tqdm(test_loader, desc='Test', leave=False):         
        inputs, labels = inputs.to(device), labels.to(device)        
        outputs = lstm_classifier(inputs)                       
        loss = criterion(outputs, labels)            
        test_loss.append(loss.item())
test_log.append(np.mean(test_loss))    
plt.plot(range(1, epoch+1), train_log, color='C0')
plt.plot(range(1, epoch+1), test_log, color='C1')
display.clear_output(wait=True)
display.display(plt.gcf())

error is:错误是:

RuntimeError Traceback (most recent call last) in () RuntimeError Traceback (最近一次调用最后一次) in ()

     23         print((labels.dtype))
     24         print(outputs[:,0].dtype)
---> 25         loss = criterion(outputs, labels)
     26         loss.backward()
     27         optimizer.step()

2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2822     if size_average is not None or reduce is not None:
   2823         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2825 
   2826 

RuntimeError: expected scalar type Float but found Double运行时错误:预期标量类型为 Float,但发现为 Double

How to fix it?如何解决?

RuntimeError: expected scalar type Float but found Double运行时错误:预期标量类型为 Float,但发现为 Double

The error at line loss = criterion(outputs, labels) is quite clear in that it requites your datatype to be float rather than double, but it doesn't explicitly say whether outputs or label is creating this.loss = criterion(outputs, labels)的错误非常清楚,因为它要求您的数据类型为浮点数而不是双精度数,但它没有明确说明是outputs还是label创建了这个。

My guess is its because of labels.我的猜测是因为标签。 Try converting it to float by doing labels.float()尝试通过执行labels.float()将其转换为浮动

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

相关问题 RuntimeError:预期的标量类型 Double 但发现 Float - RuntimeError: expected scalar type Double but found Float RuntimeError:预期的标量类型为 Double 但发现为 Float - RuntimeError: expected scalar type Double but found Float Pytorch CNN 训练中的“RuntimeError: expected scalar type Double but found Float” - “RuntimeError: expected scalar type Double but found Float” in Pytorch CNN training RuntimeError: 预期标量类型 Long 但发现 Float - RuntimeError: expected scalar type Long but found Float Pytorch 为什么这里需要 is.float() 来解决 RuntimeError:预期标量类型 Float 但发现 Double - Pytorch why is .float() needed here for RuntimeError: expected scalar type Float but found Double 为什么不转换 Tensor 的 dtype 修复“运行时错误:预期标量类型 Double 但发现 Float”? - Why doesn't converting the dtype of a Tensor fix "RuntimeError: expected scalar type Double but found Float"? RuntimeError:预期的标量类型 Long 但发现 Float (Pytorch) - RuntimeError: expected scalar type Long but found Float (Pytorch) “RuntimeError: 预期标量类型 Long 但发现 Float”在转发过程中 - "RuntimeError: expected scalar type Long but found Float" in the forward process RuntimeError:预期的标量类型 Float 但发现 Long 神经网络 - RuntimeError: expected scalar type Float but found Long neural network Pytorch 几何:RuntimeError:预期的标量类型 Long 但发现 Float - Pytorch Geometric: RuntimeError: expected scalar type Long but found Float
 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM