
Custom GAN training loop using tf.GradientTape returns [None] as gradients for generator while it works for discriminator

I am trying to train a GAN. Somehow the gradient for the generator returns None even though it returns gradients for the discriminator. This leads to ValueError: No gradients provided for any variable: ['carrier_freq:0']. when the optimizer applies the gradients to the weights (in this case just a single weight, so there should be a single gradient). I can't seem to find the reason for that, as the computation should be almost the same.

This is the code for the train step where the gradients of the generator return [None].

import tensorflow as tf
from tensorflow import keras

generator = make_generator()
discriminator = make_discriminator()

g_loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
g_optimizer = keras.optimizers.Adam(learning_rate=0.04)
d_loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
d_optimizer = keras.optimizers.Adam(learning_rate=0.03)

def train_step(train_set):
    # modulate or don't modulate sample
    for batch in train_set:
        # get a random DEMAND noise sample to mix with speech
        noise_indices = tf.random.uniform([batch_size], minval=0, maxval=len(demand_dataset), dtype=tf.int32)
        
        # labels of 0 representing legit samples
        legit_labels = tf.zeros(batch_size, dtype=tf.uint8)
        # labels of 1 representing adversarial samples
        adversarial_labels = tf.ones(batch_size, dtype=tf.uint8)
        # concat legit and adversarial labels
        concat_labels = tf.concat((legit_labels, adversarial_labels), axis=0)
        
        # calculate gradients
        with tf.GradientTape(persistent=True) as tape:
            legit_predictions = discriminator(legit_path(batch, noise_indices))
            adversarial_predictions = discriminator(adversarial_path(batch, noise_indices))
            # concat legit and adversarial predictions to match double batch of concat_labels
            d_predictions = tf.concat((legit_predictions, adversarial_predictions), axis=0)
            d_loss = d_loss_fn(concat_labels, d_predictions)
            g_loss = g_loss_fn(legit_labels, adversarial_predictions)
            print('Discriminator loss: ' + str(d_loss))
            print('Generator loss: ' + str(g_loss))
        d_grads = tape.gradient(d_loss, discriminator.trainable_weights)
        g_grads = tape.gradient(g_loss, generator.trainable_weights)
        print(g_grads)  # prints [None]: no gradient reaches the generator's weight
        d_optimizer.apply_gradients(zip(d_grads, discriminator.trainable_weights))
        g_optimizer.apply_gradients(zip(g_grads, generator.trainable_weights))
        
        discriminator_loss(d_loss)
        generator_loss(g_loss)
    return d_loss, g_loss

Here is some information about what happens there:
The discriminator's goal is to distinguish between legit and adversarial samples. The discriminator receives double the batch: once the batch is preprocessed in a way that yields legit data, and once in a way that produces adversarial data, i.e. the data is passed through the generator and modified there.
The generator currently has only a single weight and consists of addition and multiplication operations wrapped in lambda layers.
The losses are calculated as BinaryCrossentropy between the labels and the predictions. The discriminator receives the true labels, which represent whether or not each sample was modified. The generator loss is calculated similarly, but it only considers the samples that were modified, paired with the labels that represent legit samples. So it basically measures how many adversarial samples are classified as legit by the discriminator.
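A toy illustration of the generator loss described above (all values are made up; this is not code from the question):

import tensorflow as tf

bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
legit_labels = tf.zeros(2)                     # 0 = legit
adversarial_logits = tf.constant([2.0, -1.0])  # discriminator outputs for generated samples
# the generator is rewarded when adversarial samples look legit (label 0),
# so its loss drops as the discriminator's logits for them go negative
print(bce(legit_labels, adversarial_logits).numpy())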

Now on to the problem:
Both loss calculations seem to work, as they return a value. The calculation of gradients also works for the discriminator, but the gradients of the generator return [None]. It should work quite similarly to the calculation of the discriminator gradients, as the only difference is that the generator loss uses a subset of the data that is used for the discriminator loss. Another difference is that the generator only has a single weight and consists of lambda layers doing multiplication and addition, whereas the discriminator is a Dense net and has more than one weight.

Does anyone have an idea what the root of the problem could be?

I found the problem and a solution. The problem could not be deduced from the code provided in the question, but I still want to write up the solution for the unlikely event that someone runs into the same issue in a similar situation.

Problem: The generator weight was part of a tf.keras.layers.Lambda() layer. Such weights or variables are not tracked for gradient calculation. More information here: https://programming.vip/docs/day-6-tensorflow2-model-subclassing-api.html
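Here is a minimal sketch of the failure mode (the variable name and the computation are assumptions modelled on the question): a tf.Variable captured by a Lambda layer's closure is not registered as a weight of the layer, so it never shows up where Keras and the tape expect it.

import tensorflow as tf

carrier_freq = tf.Variable(0.5, name='carrier_freq')
lambda_gen = tf.keras.Sequential([
    tf.keras.layers.Lambda(lambda x: x * carrier_freq)
])
lambda_gen.build(input_shape=(None, 1))
print(lambda_gen.trainable_weights)  # [] -- the variable is not tracked by the layer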
Solution: Write a custom layer inheriting from the base layer class, as described in the article linked above.
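A sketch of such a custom layer, using add_weight() so the weight appears in trainable_weights and receives gradients (the weight name, initial value, and the multiplication are assumptions based on the question; the real generator would wrap its actual operations here):

import tensorflow as tf

class Modulator(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # registered as a trainable weight, unlike a variable captured by a Lambda
        self.carrier_freq = self.add_weight(
            name='carrier_freq', shape=(),
            initializer=tf.keras.initializers.Constant(0.5), trainable=True)

    def call(self, inputs):
        # stand-in for the original multiplication/addition operations
        return inputs * self.carrier_freq

generator = tf.keras.Sequential([Modulator()])
generator.build(input_shape=(None, 1))
print(generator.trainable_weights)  # [<tf.Variable '...carrier_freq:0' shape=() ...>]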

I think this is because you have not called the generator inside the GradientTape(). As the discriminator is called twice within with tf.GradientTape(persistent=True) as tape:, you should call the generator there as well, e.g. generator(noise, training=True). This way the generator's gradients will be evaluated as well.
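A minimal sketch of what this suggestion amounts to, with the generator call made explicit under the tape (legit_path, the generator's input, and the training=True flags are assumptions carried over from the question):

with tf.GradientTape(persistent=True) as tape:
    legit_samples = legit_path(batch, noise_indices)
    # explicit generator forward pass inside the tape context
    adversarial_samples = generator(legit_samples, training=True)
    legit_predictions = discriminator(legit_samples, training=True)
    adversarial_predictions = discriminator(adversarial_samples, training=True)
    d_predictions = tf.concat((legit_predictions, adversarial_predictions), axis=0)
    d_loss = d_loss_fn(concat_labels, d_predictions)
    g_loss = g_loss_fn(legit_labels, adversarial_predictions)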
