
RuntimeError: expected scalar type Float but found Double (LSTM classifier)

I'm training my LSTM classifier.

epoch_num = 30

train_log = []
test_log = []
set_seed(111)
for epoch in range(1, epoch_num+1):
    running_loss = 0
    train_loss = []
    lstm_classifier.train()
    for (inputs, labels) in tqdm(train_loader, desc='Training epoch ' + str(epoch), leave=False):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = lstm_classifier(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    train_log.append(np.mean(train_loss))

    running_loss = 0
    test_loss = []
    lstm_classifier.eval()
    with torch.no_grad():
        for (inputs, labels) in tqdm(test_loader, desc='Test', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = lstm_classifier(inputs)
            loss = criterion(outputs, labels)
            test_loss.append(loss.item())
    test_log.append(np.mean(test_loss))
    plt.plot(range(1, epoch+1), train_log, color='C0')
    plt.plot(range(1, epoch+1), test_log, color='C1')
    display.clear_output(wait=True)
    display.display(plt.gcf())

The error is:

RuntimeError Traceback (most recent call last) in ()

     23         print((labels.dtype))
     24         print(outputs[:,0].dtype)
---> 25         loss = criterion(outputs, labels)
     26         loss.backward()
     27         optimizer.step()

2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
   2822     if size_average is not None or reduce is not None:
   2823         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
   2825 
   2826 

RuntimeError: expected scalar type Float but found Double

How to fix it?

RuntimeError: expected scalar type Float but found Double

The error at the line loss = criterion(outputs, labels) is fairly clear: it requires a tensor of dtype float (float32) rather than double (float64). It does not say explicitly, though, whether outputs or labels is the tensor causing it.
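One way to find out which tensor is the Double one is to print the dtype of both right before the loss call. The sketch below uses made-up stand-ins for a single batch: the shapes, the values, and the idea that the labels came from a NumPy float array are assumptions for illustration, not something taken from the question.

```python
import numpy as np
import torch

# Hypothetical stand-ins for one batch (not the asker's real data).
# torch.randn produces float32 by default, like typical model outputs.
outputs = torch.randn(4, 3)

# NumPy float arrays default to float64, and torch.from_numpy keeps that
# dtype, so labels built this way end up as Double tensors.
labels = torch.from_numpy(np.eye(3)[[0, 2, 1, 0]])

print(outputs.dtype)  # torch.float32 (Float)
print(labels.dtype)   # torch.float64 (Double)
```

Whichever tensor reports torch.float64 is the one to convert.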

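My guess is that it's because of labels. Try converting them to float with labels = labels.float() before computing the loss. Note that .float() returns a new tensor rather than modifying labels in place, so the result has to be assigned back.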
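A minimal sketch of that conversion follows. It assumes the labels are one-hot (per-class probability) targets stored as float64, which nn.CrossEntropyLoss accepts in PyTorch 1.10 and later. The batch itself is invented for illustration.

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

# Hypothetical batch: float32 logits for 4 samples and 3 classes.
outputs = torch.randn(4, 3)

# Assumed label format: one-hot rows stored as float64 (Double).
labels = torch.tensor(
    [[1.0, 0.0, 0.0],
     [0.0, 0.0, 1.0],
     [0.0, 1.0, 0.0],
     [1.0, 0.0, 0.0]],
    dtype=torch.float64,
)

# .float() returns a float32 copy, so assign the result back.
labels = labels.float()

loss = criterion(outputs, labels)
print(labels.dtype, loss.dtype)
```

If the labels are class indices rather than one-hot rows, CrossEntropyLoss expects them as int64 instead, so labels.long() would be the conversion to try in that case.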
