Simplified Chinese | Traditional Chinese | Chinese–English

GAN: Generator and Discriminator Losses Are Lowest in the First Epoch — Is That Normal?

I am trying to train a simple GAN, and I noticed that the losses of both the generator and the discriminator are lowest in the first epoch. How can that be? Did I miss something?

Below is a plot of the loss over the iterations: (image: GAN loss)

Here is the code I was using:

I adapted the code according to your suggestion, @emrejik, but it doesn't seem to have changed much. I couldn't use torch.ones() on the suggested lines, because I got this error: "argument size (position 1) must be tuple of ints, not Tensor". Any idea why?

I ran it for 10 epochs, and this is the result: (image: LOSS of GAN_Second Try)

from glob import glob
import sys
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.utils as vutils
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from mpl_toolkits.axes_grid1 import ImageGrid
from skimage import io
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
from tqdm import trange


manual_seed = 999
# Apply the seed so torch-driven randomness (weight init, latent noise,
# DataLoader shuffling) repeats from run to run.
torch.manual_seed(manual_seed)

path = 'Punks'
# glob() returns files in filesystem-dependent order; sorting keeps the
# dataset indexing stable so the fixed seed actually yields repeatable runs.
image_paths = sorted(glob(path + '/*.png'))
img_size = 28
batch_size = 32


device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Preprocessing: resize and center-crop to img_size x img_size, convert to a
# float tensor, then normalise each of the 3 channels with mean/std 0.5 so
# values go from [0, 1] to [-1, 1] (the range of the generator's Tanh output).
transform = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)


class ImageDataset(Dataset):
    """Map-style dataset that loads one image per file path.

    Args:
        paths: Sequence of image file paths; item ``i`` is loaded from
            ``paths[i]``.
        transform: Optional callable applied to each loaded image. When it is
            ``None`` the loaded image is returned unchanged.
        loader: Optional callable ``path -> image``. When ``None`` (the
            default), ``io.imread`` is used.
    """

    def __init__(self, paths, transform=None, loader=None):
        self.paths = paths
        self.transform = transform
        self.loader = loader

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        image_path = self.paths[index]
        # Resolve the reader lazily so io.imread is only touched when no
        # custom loader was supplied.
        read = self.loader if self.loader is not None else io.imread
        image = read(image_path)

        if self.transform is not None:
            return self.transform(image)
        # No transform: hand back the raw image instead of failing on an
        # unbound result variable.
        return image


   
if __name__ == "__main__":
    # Everything below runs only when this file is executed as a script.
    # Build the dataset from the globbed PNG paths and wrap it in a shuffling
    # DataLoader that uses two worker processes.
    dataset = ImageDataset(paths=image_paths, transform=transform)
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
 
       
    class Discriminator(nn.Module):
        """MLP discriminator: maps an image batch to one probability per sample.

        Args:
            channels: Number of image channels (default 3).
            size: Image height/width in pixels (default: the module-level
                ``img_size``, so the model tracks the preprocessing transform).
        """

        def __init__(self, channels=3, size=img_size):
            super().__init__()
            # Length of one flattened image; derived instead of hard-coding
            # 784*3 so changing img_size does not silently break the model.
            self.in_features = channels * size * size
            self.model = nn.Sequential(
                nn.Linear(self.in_features, 2048),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(2048, 1024),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(1024, 512),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(512, 256),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(256, 1),
                # Sigmoid output in (0, 1) is what nn.BCELoss expects.
                nn.Sigmoid(),
            )

        def forward(self, x):
            # Flatten each image to a single vector per sample.
            x = x.view(x.size(0), self.in_features)
            return self.model(x)

    discriminator = Discriminator().to(device=device)

       
    class Generator(nn.Module):
        """MLP generator: maps latent vectors to an image batch in [-1, 1].

        Args:
            latent_dim: Length of each input latent vector (default 100).
            channels: Number of output image channels (default 3).
            size: Output height/width in pixels (default: the module-level
                ``img_size``, so generated images match the real ones).
        """

        def __init__(self, latent_dim=100, channels=3, size=img_size):
            super().__init__()
            self.channels = channels
            self.image_size = size
            self.model = nn.Sequential(
                nn.Linear(latent_dim, 256),
                nn.ReLU(),
                nn.Linear(256, 512),
                nn.ReLU(),
                nn.Linear(512, 1024),
                nn.ReLU(),
                nn.Linear(1024, 2048),
                nn.ReLU(),
                nn.Linear(2048, channels * size * size),
                # Tanh keeps outputs in [-1, 1], the same range as the
                # normalised real images.
                nn.Tanh(),
            )

        def forward(self, x):
            output = self.model(x)
            # Reshape the flat output into (batch, channels, H, W) images.
            return output.view(
                x.size(0), self.channels, self.image_size, self.image_size)

    generator = Generator().to(device=device)

    lr = 0.0001
    num_epochs = 10
    loss_function = nn.BCELoss()
    # One fixed latent batch is reused for every snapshot, so the saved
    # image grids show the generator's progress on the same inputs.
    fixed_noise = torch.randn(64, 100, device=device)

    optimizer_discriminator = torch.optim.Adam(
        discriminator.parameters(), lr=lr)
    optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)

    # Summarise the networks that are actually trained rather than building
    # throwaway copies just for the report.
    summary(discriminator, input_size=(batch_size, 3, img_size, img_size))
    summary(generator, input_size=(batch_size, 100))

    image_list = []   # one image grid per epoch, generated from fixed_noise
    Dis_losses = []   # discriminator loss per iteration
    Gen_losses = []   # generator loss per iteration
    iters = 0
    num_batches = len(train_loader)

    for epoch in trange(num_epochs, bar_format='{desc:<5.5}{percentage:3.0f}%|{bar:120}{r_bar}\n'):
        for n, real_samples in enumerate(train_loader):

            # The last batch can be smaller than batch_size, so size the
            # labels and noise from this batch (without overwriting the
            # batch_size setting).
            current_batch_size = real_samples.size(0)
            real_samples = real_samples.to(device=device)
            real_samples_labels = torch.ones(
                (current_batch_size, 1), device=device)
            fake_samples_labels = torch.zeros(
                (current_batch_size, 1), device=device)

            # ---- Discriminator step: real -> 1, fake -> 0 ----
            latent_space_samples = torch.randn(
                (current_batch_size, 100), device=device)
            # detach(): the discriminator loss must not backpropagate
            # through the generator.
            fake_samples = generator(latent_space_samples).detach()

            discriminator.zero_grad()

            output_discriminator_real = discriminator(real_samples)
            loss_discriminator_real = loss_function(
                output_discriminator_real, real_samples_labels)

            output_discriminator_fake = discriminator(fake_samples)
            loss_discriminator_fake = loss_function(
                output_discriminator_fake, fake_samples_labels)

            loss_discriminator = (
                loss_discriminator_real + loss_discriminator_fake) / 2

            loss_discriminator.backward()
            optimizer_discriminator.step()

            # ---- Generator step: push D to label fresh fakes as real ----
            latent_space_samples = torch.randn(
                (current_batch_size, 100), device=device)

            generator.zero_grad()

            fake_samples = generator(latent_space_samples)
            output_discriminator_fake = discriminator(fake_samples)
            loss_generator = loss_function(
                output_discriminator_fake, real_samples_labels)

            loss_generator.backward()
            optimizer_generator.step()

            Dis_losses.append(loss_discriminator.item())
            Gen_losses.append(loss_generator.item())
            iters += 1

            # Report once per epoch, on its final batch.
            if n == num_batches - 1:
                print(f"Epoch: {epoch} Loss D.: {loss_discriminator.item()}")
                print(f"Epoch: {epoch} Loss G.: {loss_generator.item()}")

        # Snapshot generated images once per epoch. The previous per-iteration
        # grid of label tensors showed nothing useful and grew without bound.
        with torch.no_grad():
            snapshot = generator(fixed_noise).cpu()
        image_list.append(vutils.make_grid(snapshot, padding=2, normalize=True))

    plt.figure(figsize=(10, 5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(Dis_losses, label="D")
    plt.plot(Gen_losses, label="G")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

I don't see the loss function being used correctly for the discriminator. You should pass the real samples and the generated samples to the discriminator separately. I think you should change your code to something like this:

# Sketch of the suggested discriminator/generator loss computation.
# NOTE(review): names presumably map to the question's code as
# disc ~ discriminator, loss_func ~ loss_function, real ~ real_samples.
fake = generator(noise)
disc_real = disc(real)
# Real images are scored against all-ones targets; ones_like builds a target
# with disc_real's shape, dtype and device (torch.ones expects a size, not a tensor).
loss_disc_real = loss_func(disc_real, torch.ones_like(disc_real))
disc_fake = disc(fake)
# Generated images are scored against all-zeros targets.
loss_disc_fake = loss_func(disc_fake,torch.zeros_like(disc_fake))
# Average of the real and fake terms is the discriminator loss.
loss_disc = (loss_disc_real+loss_disc_fake)/2
....

# The generator is trained to make the discriminator output 1 for fakes.
# NOTE(review): disc_fake was computed before the elided discriminator update;
# recomputing disc(fake) after that step is the more common pattern — confirm.
loss_generator = loss_func(disc_fake,torch.ones_like(disc_fake))
...

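Then plot loss_disc and loss_generator; this should work. Regarding the torch.ones() error: torch.ones() expects a size (a tuple of ints), not a tensor, so use torch.ones_like(t), or torch.ones(t.size()), to build a target with the same shape as t.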

The technical posts on this site are published under the CC BY-SA 4.0 license. If you reprint them, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM