
How do I make my PyTorch DCGAN code run on a GPU?

I am trying to train a DCGAN on a GPU. Since I am just starting out with PyTorch, I followed the documentation and it works, but I want to confirm whether the way I have done it is correct, because the other questions I looked at about running on the GPU all do it in different ways.

import os
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.nn import (
    BatchNorm2d,
    BCELoss,
    Conv2d,
    ConvTranspose2d,
    LeakyReLU,
    Module,
    ReLU,
    Sequential,
    Sigmoid,
    Tanh,
)
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.utils import save_image
from tqdm import tqdm

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

dataset = DataLoader(
    CIFAR10(
        root="./Data",
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(64),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    ),
    batch_size=64,
    shuffle=True,
    num_workers=2,
)

try:
    os.mkdir("./Models")
    os.mkdir("./Results")

except FileExistsError:
    pass


class Gen(Module):
    def __init__(self):
        super(Gen, self).__init__()
        self.main = Sequential(
            ConvTranspose2d(100, 512, 4, 1, 0, bias=False),
            BatchNorm2d(512),
            ReLU(True),
            ConvTranspose2d(512, 256, 4, 2, 1, bias=False),
            BatchNorm2d(256),
            ReLU(True),
            ConvTranspose2d(256, 128, 4, 2, 1, bias=False),
            BatchNorm2d(128),
            ReLU(True),
            ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            BatchNorm2d(64),
            ReLU(True),
            ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            Tanh(),
        )

    def forward(self, input):
        output = self.main(input)
        return output


class Dis(Module):
    def __init__(self):
        super(Dis, self).__init__()
        self.main = Sequential(
            Conv2d(3, 64, 4, 2, 1, bias=False),
            LeakyReLU(0.2, inplace=True),
            Conv2d(64, 128, 4, 2, 1, bias=False),
            BatchNorm2d(128),
            LeakyReLU(0.2, inplace=True),
            Conv2d(128, 256, 4, 2, 1, bias=False),
            BatchNorm2d(256),
            LeakyReLU(0.2, inplace=True),
            Conv2d(256, 512, 4, 2, 1, bias=False),
            BatchNorm2d(512),
            LeakyReLU(0.2, inplace=True),
            Conv2d(512, 1, 4, 1, 0, bias=False),
            Sigmoid(),
        )

    def forward(self, input):
        output = self.main(input)
        return output.view(-1)


def weights(obj):
    classname = obj.__class__.__name__
    if classname.find("Conv") != -1:
        obj.weight.data.normal_(0.0, 0.02)

    elif classname.find("BatchNorm") != -1:
        obj.weight.data.normal_(1.0, 0.02)
        obj.bias.data.fill_(0)


gen = Gen()
gen.apply(weights)
gen.cuda().cuda().to(device)

dis = Dis()
dis.apply(weights)
dis.cuda().to(device)

criterion = BCELoss()

optimizerDis = Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerGen = Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999))

for epoch in range(25):
    for batch, data in enumerate(tqdm(dataset, total=len(dataset)), 0):
        dis.zero_grad()

        input = Variable(data[0]).cuda().to(device)
        target = Variable(torch.ones(input.size()[0])).cuda().to(device)
        output = dis(input).cuda().to(device)

        realError = criterion(output, target)

        noise = Variable(torch.randn(input.size()[0], 100, 1, 1)).cuda().to(device)
        fake = gen(noise).cuda().to(device)
        target = Variable(torch.zeros(input.size()[0])).cuda().to(device)
        output = dis(fake.detach()).cuda().to(device)

        fakeError = criterion(output, target)

        errD = realError + fakeError
        errD.backward()
        optimizerDis.step()

        gen.zero_grad()

        target = Variable(torch.ones(input.size()[0])).cuda().to(device)
        output = dis(fake).cuda().to(device)

        errG = criterion(output, target)
        errG.backward()
        optimizerGen.step()

        print(f"  {epoch+1}/25 Dis Loss: {errD.data:.4f} Gen Loss: {errG.data:.4f}")

save_image(data[0], "./Results/Real.png", normalize=True)
save_image(gen(noise).data, f"./Results/Fake{epoch+1}.png", normalize=True)
torch.save(gen, f"./Models/model{epoch+1}.pth")

A few comments about your code:

  1. What is the value of your device variable?
    Make sure it is torch.device('cuda:0') (or whatever your GPU's device id is).
    Otherwise, if your device is actually torch.device('cpu'), then you are running on the CPU.
    See torch.device for more information.
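    For example, a quick sanity check (a minimal sketch; the print statements are just illustrative):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(torch.cuda.is_available())  # True means a CUDA device is visible
print(device)                     # should print "cuda:0" if you are on the GPU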

  2. You removed the "model" part of your code, but you may have skipped an important part there: have you moved your model to the GPU as well? Usually a model contains many internal parameters (aka trainable weights), and you need them on the device as well.
    Your code should also have

dis.to(device)
criterion.to(device)  # if your loss function also has trainable parameters

Note that unlike torch.tensors, calling .to() on an nn.Module is an "in place" operation.
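
A small sketch of that difference (illustrative only, assuming a CUDA device is available):

import torch
import torch.nn as nn

device = torch.device("cuda:0")

t = torch.zeros(3)
t.to(device)      # returns a NEW tensor; t itself is still on the CPU
t = t.to(device)  # you must re-assign to actually move a tensor

m = nn.Linear(3, 3)
m.to(device)      # moves the module's parameters in place; no re-assignment needed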

  3. You have redundancy in your code: you do not have to call both .cuda() AND .to().
    Calling .cuda() was the old way of moving things to the GPU, but for a while now PyTorch has had .to() to make coding simpler.
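    For example, taking the line from your code:

gen.cuda().cuda().to(device)   # redundant: moves the model to the GPU three times
gen.to(device)                 # does the same thing once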

  4. Since both your inputs and your model are on the GPU, you do not need to explicitly move the outputs to the device as well. Thus, you can replace output = dis(input).cuda().to(device) with just output = dis(input).

  5. No need to use Variable explicitly. You can replace

noise = Variable(torch.randn(input.size()[0], 100, 1, 1)).cuda().to(device)

with

noise = torch.randn(input.size(0), 100, 1, 1, device=input.device)

You can also use torch.zeros_like and torch.ones_like for the target variable:

target = torch.zeros_like(output)

Note that zeros_like and ones_like take care of the device (and data type) for you - it will be the same as output's.
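
Putting these suggestions together, the training step might look roughly like this (a sketch only, keeping your hyperparameters, models, and optimizers exactly as you defined them):

for epoch in range(25):
    for batch, data in enumerate(tqdm(dataset, total=len(dataset)), 0):
        dis.zero_grad()

        # real batch: move the images once, build the target on the same device
        input = data[0].to(device)
        output = dis(input)
        target = torch.ones_like(output)
        realError = criterion(output, target)

        # fake batch: the noise is created directly on the right device
        noise = torch.randn(input.size(0), 100, 1, 1, device=device)
        fake = gen(noise)
        output = dis(fake.detach())
        target = torch.zeros_like(output)
        fakeError = criterion(output, target)

        errD = realError + fakeError
        errD.backward()
        optimizerDis.step()

        # generator step
        gen.zero_grad()
        output = dis(fake)
        target = torch.ones_like(output)
        errG = criterion(output, target)
        errG.backward()
        optimizerGen.step()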
