
No gradients provided in TensorFlow Keras with a custom training step

I am experimenting with different VAE implementations in TensorFlow Keras. With the following model, I get an error that no gradients are provided for any variable in any layer.

import tensorflow as tf
from tensorflow import keras

tfk = tf.keras
tfkl = tf.keras.layers

class sampling2(tfk.layers.Layer):
    # Reparameterization trick: z = mean + std * eps, with eps ~ N(0, I)
    def call(self, inputs):
        z_mean, z_log_var = inputs
        batch_size = tf.shape(z_mean)[0]
        dim_z = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch_size, dim_z))
        z_sample = z_mean + tf.exp(0.5 * z_log_var) * epsilon
        return z_sample

class encoder2(tfk.layers.Layer):
    # Maps inputs to (z_mean, z_log_var) and a sampled latent z
    def __init__(self, latent_dim = 30, intermediate_dim = 200, name= 'encoder2', **kwargs):
        super(encoder2, self).__init__(name = name, **kwargs)
        self.dense_1 = tfkl.Dense(intermediate_dim, activation="relu") 
        self.dense_mean = tfkl.Dense(latent_dim)
        self.dense_log_var = tfkl.Dense(latent_dim)
        self.sampling = sampling2()

    def call(self, inputs):
        x = self.dense_1(inputs)
        z_mean = self.dense_mean(x)
        z_log_var = self.dense_log_var(x)
        z = self.sampling((z_mean, z_log_var))
        return z_mean, z_log_var, z

class decoder2(tfk.layers.Layer):
    # Maps a latent sample z back to a reconstruction in [0, 1]
    def __init__(self, original_dim, intermediate_dim= 200, name = 'decoder2', **kwargs):
        super(decoder2, self).__init__(name = name, **kwargs)
        self.dense_1 = tfkl.Dense(intermediate_dim, activation='relu')     
        self.dense_output = tfkl.Dense(original_dim, activation = 'sigmoid')

    def call(self, inputs):
        x = self.dense_1(inputs)
        logits = self.dense_output(x)
        return logits
    
class VAE2(tfk.Model):
    def __init__(self, original_dim, intermediate_dim = 800, latent_dim = 50,
               name = 'VAE2', **kwargs):
        super(VAE2, self).__init__(name = name, **kwargs)
        self.original_dim = original_dim
        self.encoder = encoder2(latent_dim = latent_dim, intermediate_dim = intermediate_dim)
        self.decoder = decoder2(original_dim, intermediate_dim = intermediate_dim)
    
    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        return reconstructed

    def training_step(self, inputs):
        dense_train_batch = tf.sparse.to_dense(inputs)
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(dense_train_batch)
            reconstructed = self.decoder(z)
            kl_loss = -0.5 * tf.reduce_mean(
                z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
            self.add_loss(kl_loss)
        grads = tape.gradient(loss, self.trainable_weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))
        

loss_fn = keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
vae2 = VAE2(df_track_names_reduced.shape[0])
vae2.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001, amsgrad=True), loss=loss_fn)
vae2.fit(train_dataset, epochs=20)

Below is the error message:

ValueError: No gradients provided for any variable: ['VAE2/encoder2/dense_8/kernel:0', 'VAE2/encoder2/dense_8/bias:0', 'VAE2/encoder2/dense_9/kernel:0', 'VAE2/encoder2/dense_9/bias:0', 'VAE2/encoder2/dense_10/kernel:0', 'VAE2/encoder2/dense_10/bias:0', 'VAE2/decoder2/dense_11/kernel:0', 'VAE2/decoder2/dense_11/bias:0', 'VAE2/decoder2/dense_12/kernel:0', 'VAE2/decoder2/dense_12/bias:0'].

You have to pass a loss tensor to tape.gradient, not a function. Calculate the binary cross-entropy loss, add it to kl_loss (loss = binary_loss + kl_loss), and then pass that total to tape.gradient().
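For instance, a corrected step could look like the sketch below. It assumes the method is renamed to train_step (the hook fit() actually invokes; the training_step above is never called by Keras) and that compile() supplied the BinaryCrossentropy loss and the optimizer, which are then available as self.compiled_loss and self.optimizer:

    def train_step(self, data):
        dense_train_batch = tf.sparse.to_dense(data)
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = self.encoder(dense_train_batch)
            reconstructed = self.decoder(z)
            kl_loss = -0.5 * tf.reduce_mean(
                z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
            # Reconstruction term from the compiled BinaryCrossentropy
            binary_loss = self.compiled_loss(dense_train_batch, reconstructed)
            loss = binary_loss + kl_loss  # a concrete tensor, not a function
        grads = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return {"loss": loss}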

If you apply gradients manually, you should not call model.compile() and model.fit(). Build your own custom loop instead. See here: https://keras.io/guides/writing_a_training_loop_from_scratch/ .
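A bare-bones version of such a loop, sketched under the assumption that train_dataset yields sparse input batches as in the question:

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = tf.keras.losses.BinaryCrossentropy()
vae2 = VAE2(df_track_names_reduced.shape[0])

for epoch in range(20):
    for batch in train_dataset:
        x = tf.sparse.to_dense(batch)
        with tf.GradientTape() as tape:
            z_mean, z_log_var, z = vae2.encoder(x)
            reconstructed = vae2.decoder(z)
            kl_loss = -0.5 * tf.reduce_mean(
                z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
            # Total loss = reconstruction + KL, passed to tape.gradient as a tensor
            loss = loss_fn(x, reconstructed) + kl_loss
        grads = tape.gradient(loss, vae2.trainable_weights)
        optimizer.apply_gradients(zip(grads, vae2.trainable_weights))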

But I don't think you really need to apply gradients manually. I would just add kl_loss within the call function via add_loss(). See here: https://keras.io/api/losses/
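A sketch of that simpler route: register the KL term inside call() with add_loss(); Keras then adds everything in model.losses on top of the compiled loss, so compile()/fit() work unchanged and no train_step override is needed (__init__ stays as in the question, and the dataset is assumed to provide the batch as both input and target for the binary cross-entropy):

class VAE2(tfk.Model):
    # __init__ as in the question

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        kl_loss = -0.5 * tf.reduce_mean(
            z_log_var - tf.square(z_mean) - tf.exp(z_log_var) + 1)
        self.add_loss(kl_loss)  # fit() adds this on top of the compiled loss
        return reconstructed

vae2 = VAE2(df_track_names_reduced.shape[0])
vae2.compile(optimizer=keras.optimizers.Adam(learning_rate=0.001, amsgrad=True),
             loss=keras.losses.BinaryCrossentropy())
vae2.fit(train_dataset, epochs=20)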
