I am trying to train a simple GAN, and I noticed that the losses for the generator and the discriminator are lowest in the first epoch. How can that be? Did I miss something?
Below you find the plot of the Loss over the iterations:
Here is the code I was using:
I adapted the code according to your suggestion, @emrejik. It doesn't seem to have changed much, though. I couldn't use torch.ones() at the suggested lines because I received this error message: "argument size (position 1) must be tuple of ints, not Tensor". Any idea why?
I ran through 10 epochs and this came out now:
from glob import glob
import sys
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torchvision.utils as vutils
import torchvision
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from mpl_toolkits.axes_grid1 import ImageGrid
from skimage import io
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import numpy as np
from tqdm import trange
# Seed value kept for reproducibility.
# NOTE: nothing in the code shown here passes it to a random number generator.
manual_seed = 999

# Training images: every PNG file directly inside the `Punks` directory.
path = 'Punks'
image_paths = glob(f'{path}/*.png')

# Side length (in pixels) the images are resized/cropped to, and the
# number of images per training batch.
img_size = 28
batch_size = 32

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Per-channel statistics handed to Normalize for the three colour channels.
_channel_mean = [0.5, 0.5, 0.5]
_channel_std = [0.5, 0.5, 0.5]

# Preprocessing applied to each loaded image array, in order:
# array -> PIL image -> resize -> centre crop -> tensor -> normalise.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(img_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])
class ImageDataset(Dataset):
    """Map-style dataset that loads images from a list of file paths.

    Each item is read from disk with ``io.imread`` and, when a transform
    was supplied, passed through that transform before being returned.

    Args:
        paths: Sequence of image file paths; one dataset item per path.
        transform: Optional callable applied to the loaded image. When it
            is ``None`` the image is returned exactly as ``io.imread``
            produced it.
    """

    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the image at ``index`` and apply the transform, if any."""
        image = io.imread(self.paths[index])
        # Compare against None instead of relying on truthiness, and return
        # the raw image when there is no transform: the previous version
        # only bound its result inside the `if`, so a missing transform
        # ended in an UnboundLocalError.
        if self.transform is None:
            return image
        return self.transform(image)
if __name__ == '__main__':
    # Wrap the image files in a dataset, then serve them in shuffled
    # batches using two loader worker processes.
    dataset = ImageDataset(paths=image_paths, transform=transform)
    train_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened images.

    The network is a stack of Linear -> ReLU -> Dropout(0.3) layers
    (2048, 1024, 512, 256 units) followed by a single sigmoid output, so
    every input sample is mapped to one value in [0, 1].

    Args:
        in_features: Number of values per flattened input sample. The
            default, ``784 * 3``, corresponds to a 3 x 28 x 28 image.
    """

    def __init__(self, in_features=784 * 3):
        super().__init__()
        self.in_features = in_features
        self.model = nn.Sequential(
            nn.Linear(in_features, 2048),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of sigmoid scores for ``x``.

        ``x`` may have any shape whose non-batch dimensions multiply to
        ``in_features``.
        """
        # reshape (rather than view) also accepts non-contiguous inputs,
        # e.g. a permuted or sliced batch, while still rejecting a batch
        # with the wrong number of values per sample.
        x = x.reshape(x.size(0), self.in_features)
        return self.model(x)
# Build the discriminator that will be trained and place it on the selected device.
discriminator = Discriminator()
discriminator = discriminator.to(device=device)
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to images.

    The network is a stack of Linear -> ReLU layers (256, 512, 1024, 2048
    units) followed by a Linear layer with Tanh, so every output value
    lies in [-1, 1].

    Args:
        latent_dim: Length of each input latent vector (default 100).
        img_shape: Shape ``(channels, height, width)`` of one generated
            image (default ``(3, 28, 28)``).
    """

    def __init__(self, latent_dim=100, img_shape=(3, 28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)

        # Number of values in one flattened output image.
        out_features = 1
        for dim in self.img_shape:
            out_features *= dim

        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 2048),
            nn.ReLU(),
            nn.Linear(2048, out_features),
            nn.Tanh(),
        )

    def forward(self, x):
        """Return a ``(batch, *img_shape)`` tensor generated from ``x``.

        ``x`` is expected to be a ``(batch, latent_dim)`` tensor.
        """
        flat_images = self.model(x)
        # Restore the image layout from the flat output of the last layer.
        return flat_images.view(x.size(0), *self.img_shape)
# Build the generator that will be trained and place it on the selected device.
generator = Generator()
generator = generator.to(device=device)

# Optimisation settings shared by both networks.
lr = 1e-4
num_epochs = 10
loss_function = nn.BCELoss()

# One Adam optimiser per network, both using the same learning rate.
optimizer_discriminator = torch.optim.Adam(
    discriminator.parameters(),
    lr=lr,
)
optimizer_generator = torch.optim.Adam(
    generator.parameters(),
    lr=lr,
)

# Print a layer summary for each architecture. Fresh instances are created
# for this; the networks being trained are not the ones summarised.
for model_cls, input_size in (
    (Discriminator, (batch_size, 3, 28, 28)),
    (Generator, (batch_size, 100)),
):
    model = model_cls().to(device=device)
    summary(model, input_size=input_size)

# Bookkeeping filled in by the training loop.
image_list, Dis_losses, Gen_losses = [], [], []
iters = epochs = 0
# Adversarial training loop. Each iteration performs one discriminator
# update (real batch vs. freshly generated batch) followed by one generator
# update, and records both losses for plotting.
num_batches = len(train_loader)
for epoch in trange((num_epochs), bar_format='{desc:<5.5}{percentage:3.0f}%|{bar:120}{r_bar}\n'):
    for n, real_samples in enumerate(train_loader):
        # The last batch of an epoch can be smaller than the configured
        # batch size, so size everything from the batch itself. A separate
        # name is used so the module-level `batch_size` is not overwritten.
        current_batch_size = len(real_samples)
        real_samples = real_samples.to(device=device)
        real_samples_labels = torch.ones(
            (current_batch_size, 1), device=device)
        fake_samples_labels = torch.zeros(
            (current_batch_size, 1), device=device)

        # ---- Discriminator update ----
        latent_space_samples = torch.randn(
            (current_batch_size, 100)).to(device=device)
        # Detach the fakes: this step must only update the discriminator,
        # so there is no need to backpropagate through the generator.
        fake_samples = generator(latent_space_samples).detach()
        discriminator.zero_grad()
        output_discriminator_real = discriminator(real_samples)
        loss_discriminator_real = loss_function(
            output_discriminator_real, real_samples_labels)
        output_discriminator_fake = discriminator(fake_samples)
        loss_discriminator_fake = loss_function(
            output_discriminator_fake, fake_samples_labels)
        loss_discriminator = (
            loss_discriminator_real + loss_discriminator_fake) / 2
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # ---- Generator update ----
        # New fakes are scored by the just-updated discriminator and
        # compared against the "real" labels.
        latent_space_samples = torch.randn(
            (current_batch_size, 100)).to(device=device)
        generator.zero_grad()
        fake_samples = generator(latent_space_samples)
        output_discriminator_fake = discriminator(fake_samples)
        loss_generator = loss_function(
            output_discriminator_fake, real_samples_labels)
        loss_generator.backward()
        optimizer_generator.step()

        Dis_losses.append(loss_discriminator.item())
        Gen_losses.append(loss_generator.item())
        iters += 1

        # End of epoch: `n` is a batch index, so compare it with the number
        # of batches (not with the batch size).
        if n == num_batches - 1:
            # Keep one grid of generated images per epoch. The grid is built
            # from the generated samples (not the label tensor) and moved
            # off the autograd graph and onto the CPU so that stored grids
            # do not accumulate device memory.
            image_list.append(vutils.make_grid(
                fake_samples.detach().cpu(), padding=2, normalize=True))
            print(f"Epoch: {epoch} Loss D.: {loss_discriminator.item()}")
            print(f"Epoch: {epoch} Loss G.: {loss_generator.item()}")
    # Count completed epochs once per pass over the data.
    epochs += 1

# Fresh latent batch drawn after training; nothing in the code shown here
# reads it, but it is kept for any later code that does.
latent_space_samples = torch.randn((batch_size, 100)).to(device=device)
# Draw the recorded discriminator and generator losses, one point per
# training iteration, on a single 10 x 5 inch figure.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
for losses, tag in ((Dis_losses, "D"), (Gen_losses, "G")):
    ax.plot(losses, label=tag)
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()
I don't see the loss function being used properly for the discriminator. You should pass the real samples and the generated samples to the discriminator separately. I think you should change your code to something like this:
# Sketch of the suggested loss computation; the dotted lines stand for
# steps the answer leaves out.
# Generate a batch of fake samples from the noise input.
fake = generator(noise)
# Score the real batch and compare the scores against all-ones targets.
disc_real = disc(real)
loss_disc_real = loss_func(disc_real, torch.ones_like(disc_real))
# Score the generated batch and compare the scores against all-zeros targets.
disc_fake = disc(fake)
loss_disc_fake = loss_func(disc_fake,torch.zeros_like(disc_fake))
# The discriminator loss is the mean of the real and fake terms.
loss_disc = (loss_disc_real+loss_disc_fake)/2
....
# Generator loss: the discriminator's scores on the fakes compared against
# all-ones targets.
loss_generator = loss_func(disc_fake,torch.ones_like(disc_fake))
...
Then plot loss_disc and loss_generator; this should work.
The technical post pages of this site are licensed under CC BY-SA 4.0. If you reprint this content, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.