
How to do CIFAR-10 with PyTorch on CUDA?

I'm following the CIFAR-10 PyTorch tutorial at this PyTorch page, and can't get PyTorch running on the GPU. The code is exactly as in the tutorial.

The error I get is

Traceback (most recent call last):
  File "(file path)/CIFAR10_tutorial.py", line 116, in <module>
    outputs = net(images)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "(file path)/CIFAR10_tutorial.py", line 65, in forward
    x = self.pool(F.relu(self.conv1(x)).cuda())
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/module.py", line 491, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/torch/nn/modules/conv.py", line 301, in forward
    self.padding, self.dilation, self.groups)

My CUDA version is 9.0 and my PyTorch version is 0.4.0. I have used tensorflow-gpu on this machine, so I know CUDA is set up correctly. Where exactly must I use .cuda() and .to(device), as suggested in the tutorial?

I'm leaving an answer in case anyone else is stuck on the same problem.

First, configure PyTorch to use the GPU if one is available:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

Then, in the __init__ function, move each layer to the GPU by calling .cuda() on every element of the network, e.g.

self.conv1 = nn.Conv2d(3, 24, 5).cuda()
self.pool = nn.MaxPool2d(2, 2).cuda()

If you're not sure a GPU will be available, call .to(device) on every element instead of .cuda().

In the forward(self, x) function, before any other step, I moved the input to the device:

x = x.to(device)

Right after the net object is created, move it to the device with

net.to(device)
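
As a rough sketch of what this call does: nn.Module.to() moves the parameters of every layer registered on the module, so the per-layer .cuda() calls are not strictly required. The Net class below is a cut-down, hypothetical stand-in for the tutorial's network (only conv1 and pool match the lines quoted here; fc1 and its size are assumptions), and it falls back to the CPU when no GPU is found.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
    """Hypothetical, shortened stand-in for the tutorial's network."""

    def __init__(self):
        super().__init__()
        # No .cuda() here: the layers are created on the CPU first.
        self.conv1 = nn.Conv2d(3, 24, 5)
        self.pool = nn.MaxPool2d(2, 2)
        # 32x32 input -> conv (5x5) -> 28x28 -> pool (2x2) -> 14x14
        self.fc1 = nn.Linear(24 * 14 * 14, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(x.size(0), -1)  # flatten everything except the batch dimension
        return self.fc1(x)


net = Net()
net.to(device)  # moves the parameters of conv1 and fc1 in one call

# Collect the device type of every parameter to confirm they all moved.
param_devices = {p.device.type for p in net.parameters()}
print(param_devices)
```

If the printed set contains only the type of device, every layer is already where the inputs will be sent.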

All inputs and labels should be cast to device before any operation is performed on them.

inputs, labels = inputs.to(device), labels.to(device)
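
To show where that line sits in a training step, here is a minimal sketch of one iteration. It assumes a recent PyTorch release (nn.Flatten is not in 0.4.0), uses a small hypothetical nn.Sequential model rather than the tutorial's Net, and substitutes random tensors for a batch from the CIFAR-10 trainloader so that it runs without downloading the dataset.

```python
import torch
import torch.nn as nn
import torch.optim as optim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Hypothetical small model; the tutorial's Net would be used the same way.
net = nn.Sequential(
    nn.Conv2d(3, 24, 5),
    nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Flatten(),
    nn.Linear(24 * 14 * 14, 10),
)
net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

# Stand-in for one batch from the trainloader: 4 RGB images of 32x32 pixels.
inputs = torch.randn(4, 3, 32, 32)
labels = torch.randint(0, 10, (4,))

# Move the batch to the same device as the model before the forward pass.
inputs, labels = inputs.to(device), labels.to(device)

optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()

print(outputs.shape, outputs.device)
```

In the real loop, the .to(device) line goes at the top of each iteration over the trainloader, since every new batch starts out on the CPU.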

I'm not writing out the entire code, since the link is already in the question. A few of these casts to the GPU are redundant, but they don't break anything. I might also put together an ipynb with the changes.
