
PyTorch: why is .float() needed here for RuntimeError: expected scalar type Float but found Double?

Simple question: I wanted to experiment with the simplest possible network, but I kept running into RuntimeError: expected scalar type Float but found Double unless I cast the data with .float() (see the code below, with a comment marking the line).

What I don't understand is why this cast is needed. data is already a torch.float64 tensor, so why is the explicit re-cast in the output = model(data.float()) line necessary?

Code

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
from sklearn.datasets import make_classification
from torch.utils.data import TensorDataset, DataLoader

# =============================================================================
# Simplest Example
# =============================================================================
X, y = make_classification()
X, y = torch.tensor(X), torch.tensor(y)
print("X Shape :{}".format(X.shape))
print("y Shape :{}".format(y.shape))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(X.shape[1], 128)
        self.fc2 = nn.Linear(128, 10)
        self.fc3 = nn.Linear(10, 2)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
device = torch.device("cuda") 
lr = 1
batch_size = 32
gamma = 0.7
epochs = 14
args = {'log_interval': 10, 'dry_run':False}

kwargs = {'batch_size': batch_size}
kwargs.update({'num_workers': 1,
               'pin_memory': True,
               'shuffle': True})

model = Net().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

my_dataset = TensorDataset(X,y) # create dataset
train_loader = DataLoader(my_dataset,**kwargs) #generate dataloader

cross_entropy_loss = torch.nn.CrossEntropyLoss()

for epoch in range(1, epochs + 1):
    ## Train step ##
    model.train()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data.float()) #HERE: why is .float() needed here?
        loss = cross_entropy_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args['log_interval'] == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            if args['dry_run']:
                break
    
    scheduler.step()

In PyTorch, 64-bit floating point corresponds to torch.float64 or torch.double, while 32-bit floating point corresponds to torch.float32 or torch.float.

Thus,

data is already a torch.float64 type

i.e. data is a 64-bit floating-point tensor (torch.double). It ends up that way because make_classification returns NumPy float64 arrays and torch.tensor preserves that dtype. The model's nn.Linear layers, however, create their weights and biases as torch.float32 by default, and PyTorch will not multiply a float64 input by float32 parameters, which is exactly the error you see.

By casting the input with .float(), you convert it to 32-bit floating point so it matches the parameters.

a = torch.tensor([[1., -1.], [1., -1.]], dtype=torch.double)
print(a.dtype)
# torch.float64
print(a.float().dtype)
# torch.float32
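
You can also verify the mismatch on a single layer (a minimal sketch; the sizes just mirror fc1 from the question, but any shapes behave the same):

import torch
import torch.nn as nn

layer = nn.Linear(20, 128)              # nn.Linear parameters are float32 by default
print(layer.weight.dtype)               # torch.float32

x = torch.randn(32, 20, dtype=torch.float64)
# layer(x)                              # raises the dtype-mismatch RuntimeError from the question
print(layer(x.float()).dtype)           # torch.float32 -- casting the input fixes it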

Check the different data types in the PyTorch documentation: https://pytorch.org/docs/stable/tensors.html
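
If you would rather not cast on every batch, there are two common alternatives (a sketch, keeping the rest of the training loop unchanged): build the tensors as float32 once up front, or convert the model's parameters to double instead.

from sklearn.datasets import make_classification
import torch

X_np, y_np = make_classification()

# Option 1: build the tensors as float32 up front, so model(data) needs no cast
X = torch.tensor(X_np, dtype=torch.float32)   # or torch.from_numpy(X_np).float()
y = torch.tensor(y_np)                        # int64 class labels, as CrossEntropyLoss expects

# Option 2: keep float64 data and convert the model's parameters instead
# model = Net().double().to(device)           # now model(data) accepts torch.float64 input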

