I'm training my LSTM classifier.
epoch_num = 30
train_log = []
test_log = []
set_seed(111)
for epoch in range(1, epoch_num+1):
    running_loss = 0
    train_loss = []
    lstm_classifier.train()
    for (inputs, labels) in tqdm(train_loader, desc='Training epoch ' + str(epoch), leave=False):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = lstm_classifier(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    train_log.append(np.mean(train_loss))
    running_loss = 0
    test_loss = []
    lstm_classifier.eval()
    with torch.no_grad():
        for (inputs, labels) in tqdm(test_loader, desc='Test', leave=False):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = lstm_classifier(inputs)
            loss = criterion(outputs, labels)
            test_loss.append(loss.item())
    test_log.append(np.mean(test_loss))
    plt.plot(range(1, epoch+1), train_log, color='C0')
    plt.plot(range(1, epoch+1), test_log, color='C1')
    display.clear_output(wait=True)
    display.display(plt.gcf())
RuntimeError Traceback (most recent call last) in ()
23 print((labels.dtype))
24 print(outputs[:,0].dtype)
---> 25 loss = criterion(outputs, labels)
26 loss.backward()
27 optimizer.step()
2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction)
2822 if size_average is not None or reduce is not None:
2823 reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 2824 return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
2825
2826
RuntimeError: expected scalar type Float but found Double
How to fix it?
RuntimeError: expected scalar type Float but found Double
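The error is raised at the line loss = criterion(outputs, labels). The message is clear that a tensor was expected to be float (float32) but a double (float64) was found; it doesn't say explicitly whether outputs or labels is responsible.
My guess is that it's because of labels. Try converting them with labels = labels.float() before computing the loss; .float() returns a new tensor, so the result has to be assigned. This fits the case where labels holds per-class probabilities with the same shape as outputs. If labels holds class indices instead, the cross-entropy loss expects them as torch.long, so use labels = labels.long().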
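To make the two conversions concrete, here is a minimal sketch. It does not use the asker's model or data: the batch of 4 samples and 3 classes, the tensor values, and the variable names are all made up for illustration, and the probability-target case assumes PyTorch 1.10 or newer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()

# Hypothetical logits for 4 samples and 3 classes. They are float32,
# which is what a model with default-dtype weights would output.
outputs = torch.randn(4, 3)

# Hypothetical labels that arrived as float64 (Double), e.g. because
# they were built from a NumPy float array.
labels_double = torch.tensor([0.0, 2.0, 1.0, 2.0], dtype=torch.float64)

# Case 1: labels are class indices. The loss expects torch.long here.
labels_idx = labels_double.long()
loss_idx = criterion(outputs, labels_idx)

# Case 2: labels are per-class probabilities with the same shape as the
# logits (here one-hot rows). They must match the logits' float32 dtype.
labels_prob = F.one_hot(labels_idx, num_classes=3).float()
loss_prob = criterion(outputs, labels_prob)

print(outputs.dtype, labels_idx.dtype, labels_prob.dtype)
print(loss_idx.item(), loss_prob.item())
```

In the training loop from the question, the same conversion would go right after inputs, labels = inputs.to(device), labels.to(device), or better, once inside the Dataset so every batch already has the right dtype. Printing labels.dtype and labels.shape for one batch should show which of the two cases applies.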