
Keras-WGAN Critic And Generator Accuracy Stuck At 0

I am trying to implement a WGAN in Keras, using David Foster's Generative Deep Learning book and this code as references. I wrote the simple code below. However, whenever I start training the model, the accuracy is always 0 and the losses for both the critic and the generator sit at ~0.

They stay stuck at these numbers no matter how many epochs I train for. I have tried various network configurations and different hyperparameters, but the results don't change. Google did not help much either, and I cannot pin down the source of this behavior.

This is the code I wrote.


from os.path import expanduser
import os
import struct as st

import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import RMSprop
import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true * y_pred)

class WGAN:

    def __init__(self):

        # Data Params
        self.genInput=100
        self.imChannels=1
        self.imShape = (28,28,1)

        # Build Models
        self.onBuildDiscriminator()
        self.onBuildGenerator()
        self.onBuildGAN()

        pass

    def onBuildGAN(self):

        if self.mGenerator is None or self.mDiscriminator is None: raise Exception('Generator Or Discriminator Uninitialized.')

        self.mDiscriminator.trainable=False

        self.mGAN=Sequential()
        self.mGAN.add(self.mGenerator)
        self.mGAN.add(self.mDiscriminator)

        ganOptimizer=RMSprop(lr=0.00005)
        self.mGAN.compile(loss=wasserstein_loss, optimizer=ganOptimizer, metrics=['accuracy'])

        print('GAN Model')
        self.mGAN.summary()
        pass

    def onBuildGenerator(self):

        self.mGenerator=Sequential()

        self.mGenerator.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.genInput))
        self.mGenerator.add(Reshape((7, 7, 128)))
        self.mGenerator.add(UpSampling2D())
        self.mGenerator.add(Conv2D(128, kernel_size=4, padding="same"))
        self.mGenerator.add(BatchNormalization(momentum=0.8))
        self.mGenerator.add(Activation("relu"))
        self.mGenerator.add(UpSampling2D())
        self.mGenerator.add(Conv2D(64, kernel_size=4, padding="same"))
        self.mGenerator.add(BatchNormalization(momentum=0.8))
        self.mGenerator.add(Activation("relu"))
        self.mGenerator.add(Conv2D(self.imChannels, kernel_size=4, padding="same"))
        self.mGenerator.add(Activation("tanh"))

        print('Generator Model')
        self.mGenerator.summary()
        pass

    def onBuildDiscriminator(self):

        self.mDiscriminator = Sequential()

        self.mDiscriminator.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.imShape, padding="same"))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
        self.mDiscriminator.add(ZeroPadding2D(padding=((0,1),(0,1))))
        self.mDiscriminator.add(BatchNormalization(momentum=0.8))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        self.mDiscriminator.add(BatchNormalization(momentum=0.8))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
        self.mDiscriminator.add(BatchNormalization(momentum=0.8))
        self.mDiscriminator.add(LeakyReLU(alpha=0.2))
        self.mDiscriminator.add(Dropout(0.25))
        self.mDiscriminator.add(Flatten())
        self.mDiscriminator.add(Dense(1))

        disOptimizer=RMSprop(lr=0.00005)
        self.mDiscriminator.compile(loss=wasserstein_loss, optimizer=disOptimizer, metrics=['accuracy'])

        print('Discriminator Model')
        self.mDiscriminator.summary()

        pass

    def fit(self, trainData, nEpochs=1000, batchSize=64):

        lblForReal = -np.ones((batchSize, 1))
        lblForGene = np.ones((batchSize, 1))

        for ep in range(1, nEpochs+1):

            for __ in range(5):

                # Get Valid Images
                validImages = trainData[ np.random.randint(0, trainData.shape[0], batchSize) ]

                # Get Generated Images
                noiseForGene=np.random.normal(0, 1, size=(batchSize, self.genInput))
                geneImages=self.mGenerator.predict(noiseForGene)

                # Train Critic On Valid And Generated Images With Labels -1 And 1 Respectively
                disValidLoss=self.mDiscriminator.train_on_batch(validImages, lblForReal)
                disGeneLoss=self.mDiscriminator.train_on_batch(geneImages, lblForGene)

                # Perform Critic Weight Clipping
                for l in self.mDiscriminator.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, -0.01, 0.01) for w in weights]
                    l.set_weights(weights)

            # Train Generator Using Combined Model
            geneLoss=self.mGAN.train_on_batch(noiseForGene, lblForReal)

            print(' Epoch', ep, 'Critic Valid Loss,Acc', disValidLoss, 'Critic Generated Loss,Acc', disGeneLoss, 'Generator Loss,Acc', geneLoss)
        pass

    pass

if __name__ == '__main__':
    (trainData, __), (__, __) = mnist.load_data()
    trainData = (trainData.astype(np.float32)/127.5) - 1
    trainData = np.expand_dims(trainData, axis=3)

    WGan = WGAN()
    WGan.fit(trainData)

I get output very similar to the following for every configuration I try.


 Epoch 1 Critic Valid Loss,Acc [-0.00016362152, 0.0] Critic Generated Loss,Acc [0.0003417502, 0.0] Generator Loss,Acc [-0.00016735379, 0.0]
 Epoch 2 Critic Valid Loss,Acc [-0.0001719332, 0.0] Critic Generated Loss,Acc [0.0003365979, 0.0] Generator Loss,Acc [-0.00017250411, 0.0]
 Epoch 3 Critic Valid Loss,Acc [-0.00017473527, 0.0] Critic Generated Loss,Acc [0.00032945914, 0.0] Generator Loss,Acc [-0.00017612436, 0.0]
 Epoch 4 Critic Valid Loss,Acc [-0.00017181305, 0.0] Critic Generated Loss,Acc [0.0003266656, 0.0] Generator Loss,Acc [-0.00016987178, 0.0]
 Epoch 5 Critic Valid Loss,Acc [-0.0001683443, 0.0] Critic Generated Loss,Acc [0.00032702673, 0.0] Generator Loss,Acc [-0.00016638976, 0.0]
 Epoch 6 Critic Valid Loss,Acc [-0.00017005506, 0.0] Critic Generated Loss,Acc [0.00032805002, 0.0] Generator Loss,Acc [-0.00017040147, 0.0]
 Epoch 7 Critic Valid Loss,Acc [-0.00017353195, 0.0] Critic Generated Loss,Acc [0.00033711304, 0.0] Generator Loss,Acc [-0.00017537423, 0.0]
 Epoch 8 Critic Valid Loss,Acc [-0.00017059325, 0.0] Critic Generated Loss,Acc [0.0003263024, 0.0] Generator Loss,Acc [-0.00016974319, 0.0]
 Epoch 9 Critic Valid Loss,Acc [-0.00017530039, 0.0] Critic Generated Loss,Acc [0.00032463064, 0.0] Generator Loss,Acc [-0.00017845634, 0.0]
 Epoch 10 Critic Valid Loss,Acc [-0.00017530067, 0.0] Critic Generated Loss,Acc [0.00033131015, 0.0] Generator Loss,Acc [-0.00017526663, 0.0]

I ran into a similar problem. The issue with WGAN is that the weight-clipping method really cripples the model's ability to learn, and learning can saturate very quickly: the weights are updated via backprop and are then clipped after each update. I would suggest experimenting with more extreme clipping values, e.g. [-1, 1] and [-0.0001, 0.0001]; you will surely see a change. An example of saturation: [figure: WGAN critic loss over 100,000 epochs]

As you can see, the loss went to 0.999975 within the first few hundred iterations and then did not move at all for 100,000 iterations. I experimented with different clipping values; the loss values were different, but the behavior was the same. With [-0.005, 0.005] the loss saturated at around 1, and with [-0.02, 0.02] at around 0.8.
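For reference, here is a minimal sketch of how the clipping step from the question's fit() loop could be pulled out with the bound exposed as a parameter, so different values are easy to try; the helper name clip_critic_weights and its default of 0.01 are my own, not from the original code.

import numpy as np

# Clip every weight tensor in the critic to [-clipValue, clipValue].
# Call this right after the two critic train_on_batch() calls,
# in place of the inline clipping loop.
def clip_critic_weights(critic, clipValue=0.01):
    for layer in critic.layers:
        weights = [np.clip(w, -clipValue, clipValue) for w in layer.get_weights()]
        layer.set_weights(weights)

Rerunning the training loop with clipValue set to 1.0 and then 0.0001 should make the saturation behavior described above easy to see.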

Your implementation looks correct, but sometimes with GANs there is only so much you can do. So I suggest you try WGAN with gradient penalty (WGAN-GP). It enforces the K-Lipschitz constraint in a nicer way, by penalizing the critic whenever the L2 norm of its gradient at interpolated images moves away from 1 (check out the paper). With WGAN-GP you should ideally see the critic's loss start at some large negative value and then converge towards 0.
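If you go that route, the penalty term itself is short. Below is a minimal sketch written against the same old-style keras.backend API as the code in the question; the names gradient_penalty_loss and interpolated_samples, and the penalty weight of 10 (the value used in the WGAN-GP paper), are illustrative assumptions rather than part of the original post. Note that interpolated_samples must be a symbolic tensor built inside the model (a random convex combination of a real and a generated batch, e.g. via a custom merge layer), not a NumPy array.

import numpy as np
import keras.backend as K

def gradient_penalty_loss(y_true, y_pred, interpolated_samples, penalty_weight=10):
    # y_pred is the critic's output on the interpolated batch.
    # Compute the critic's gradient with respect to the interpolated images.
    gradients = K.gradients(y_pred, interpolated_samples)[0]
    # Per-sample L2 norm of that gradient (sum over all non-batch axes).
    gradients_sqr = K.square(gradients)
    gradients_sqr_sum = K.sum(gradients_sqr, axis=np.arange(1, len(gradients_sqr.shape)))
    gradient_l2_norm = K.sqrt(gradients_sqr_sum)
    # Penalize any deviation of that norm from 1 (the 1-Lipschitz constraint).
    return penalty_weight * K.mean(K.square(1 - gradient_l2_norm))

In practice this loss is bound to the critic's interpolated output with functools.partial at compile time, and the weight-clipping loop is removed entirely.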
