
Custom keras callbacks and changing weight (beta) of regularization term in variational autoencoder loss function

The variational autoencoder loss function is: Loss = Loss_reconstruction + Beta * Loss_kld. I am trying to efficiently implement Kullback–Leibler divergence cyclic annealing, that is, dynamically changing the weight beta during training. I subclass tf.keras.callbacks.Callback as a start, but I don't know how to update a tf.keras.Model variable from a custom Keras callback.

Furthermore, I would like to track how beta changes at the end of each training step (on_train_batch_end). Right now I keep a Python list in the callback class, but I know Python lists don't play well with TensorFlow: when I fit the model, I get a warning that my on_train_batch_end function is slower than the processing of the batch itself. I think I should use a tf.TensorArray instead of Python lists, but then the tf.TensorArray write method cannot use a tf.Variable for the index (i.e., as the number of steps changes, the index in the tf.TensorArray to which a new beta for that step should be written changes)... Is there a better way to store value changes? It looks like this GitHub repo shows a solution that doesn't involve a custom tf.keras.Model and that uses a different kind of KL annealing. Below is the callback and a dummy VAE.

class CyclicAnnealing(tf.keras.callbacks.Callback):
  """Cyclic annealing from https://arxiv.org/abs/1903.10145
  
  Requires that model tracks training iterations and 
  total number of training iterations. It also requires
  that model has hyperparameter for `M` and `R`.
  """

  def __init__(self, schedule_fxn='sigmoid', **kwargs):
    super().__init__(**kwargs)

    # INEFFICIENT WAY OF LOGGING `betas` AND THE TRAIN STEPS...
    # The `train_iterations` list could be removed: given the list of betas,
    # I already know its length is (num_samples // batch_size) * epochs,
    # which is the total number of training steps for the model.
    self.betas = []
    self.train_iterations = []

    if schedule_fxn == 'sigmoid':
      self.schedule_fxn = self.sigmoid

    elif schedule_fxn =='linear':
      self.schedule_fxn = self.linear

    else:
      raise ValueError('Invalid arg: `schedule_fxn`')

  def on_epoch_end(self, epoch, logs=None):
    print('\nCurrent anneal weight B =', self.beta)

  def on_train_batch_end(self, batch, logs=None):
    """Computes betas and updates list"""

    # Compute beta
    self.beta = self.beta_tau_cyclic_annealing(self.compute_tau())

    ###################################
    # HOW TO UPDATE BETA IN THE MODEL???
    ###################################

    # Update the lists for logging
    self.betas.append(self.beta)
    self.train_iterations.append(self.model._train_counter)

  def get_annealing_data(self):
    return {'betas': self.betas, 'training_iterations': self.train_iterations}

  def sigmoid(self, x):
    """Monotonic increasing function
    
    :return: tf.constant float32
    """

    return (1/(1+tf.keras.backend.exp(-x)))

  def linear(self, x):
    return x/self.model._R

  def compute_tau(self):
    """Used to determine kld_beta.
    
    :return: tf.constant float32
    """

    t = tf.identity(self.model._train_counter)
    T = self.model._total_training_iterations
    M = self.model._M
    numerator = tf.cast(tf.math.floormod(tf.subtract(t, 1), tf.math.floordiv(T, M)), dtype=tf.float32)
    denominator = tf.cast(tf.math.floordiv(T, M), dtype=tf.float32)
    return tf.math.divide(numerator, denominator)

  def beta_tau_cyclic_annealing(self, tau):
    """Compute change for kld_beta.
    
    :param tau: Increases beta_tau
    :param R: Proportion used to increase Beta w/i cycle.

    :return: tf.constant float32
    """

    R = self.model._R
    if tau <= R:
      return self.schedule_fxn(tau)
    else:
      return tf.constant(1.0)
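
For reference, the schedule that the callback above implements is:

tau = ((t - 1) mod floor(T / M)) / floor(T / M)
beta_t = f(tau) if tau <= R, else 1

where t is the current training iteration (model._train_counter), T is the total number of training iterations, M is the number of annealing cycles, R is the proportion of each cycle spent increasing beta, and f is the chosen schedule function (sigmoid or linear).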

Dummy VAE:

class VAE(tf.keras.Model):
    def __init__(self, num_samples, batch_size, epochs, features, units, latent_size, kld_beta, M, R, **kwargs):
        """Defines state for model.

        :param num_samples: <class 'int'>
        :param batch_size: <class 'int'>
        :param epochs: <class 'int'>
        :param features: <class 'int'> If input is (n, m), then `features` is the `m` dimension. This param is used with the decoder.
        :param units: <class 'int'> Number of hidden units.
        :param latent_size: <class 'int'> Dimension of latent space z.
        :param kld_beta: <tf.Variable??> for dynamic weight.
        :param M: <class 'int'> Hyperparameter for cyclic annealing.
        :param R: <class 'float'> Hyperparameter for cyclic annealing.
        """
        super().__init__(**kwargs)

        # NEED TO UPDATE THIS SOMEHOW -- I think it should be a tf.Variable?
        self.kld_beta = kld_beta

        # Hyperparameters for CyclicAnnealing
        self._M = M
        self._R = R
        self._total_training_iterations = (num_samples//batch_size) * epochs    

        # Encoder and Decoder not defined, but typically
        # encoder = inputs -> dense -> dense mu and dense log var -> z
        # while decoder = z -> dense -> reconstructions
        self.encoder = Encoder(units, latent_size)
        self.decoder = Decoder(features)

    def call(self, inputs):
        z, mus, log_vars = self.encoder(inputs)
        reconstructions = self.decoder(z)

        kl_loss = self.compute_kl_loss(mus, log_vars)

        # THE BETA WEIGHT NEEDS TO BE DYNAMIC
        weighted_kl_loss = self.kld_beta * kl_loss
      
        self.add_loss(weighted_kl_loss)

        return reconstructions
        
    def compute_kl_loss(self, mus, log_vars):
         return -0.5 * tf.reduce_mean(1. + log_vars - tf.exp(log_vars) - tf.pow(mus, 2))

Concerning your first question: it depends on how you plan to update your gradients with your optimizer (e.g. Adam). When training a VAE with TensorFlow/Keras, I usually use the @tf.function decorator on the function that calculates the loss of my model and, based on that, updates my model's parameters:

@tf.function
def train_step(self, model, batch, gamma, capacity):
    # Forward pass and loss computation recorded on the gradient tape
    with tf.GradientTape() as tape:
        x, c = batch
        loss = compute_loss(model, x, c, gamma, capacity)
        tf.print('Total loss: ', loss)

    # Backpropagate and apply the gradients
    gradients = tape.gradient(loss, model.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Note the variables gamma and capacity. They are defined as terms which influence the loss function. I update them after a certain number of epochs as follows:

new_weight = min(tf.keras.backend.get_value(capacity) + (20. / capacity_annealtime), 20.)
tf.keras.backend.set_value(capacity, new_weight)

At this point you can easily save the new_weight for logging purposes, or you can define a custom TensorFlow logger to log into a file. If you really want to use an array, you could simply define a TF array as:

this_array = tf.TensorArray(tf.float32, size=0, dynamic_size=True)

and update it after a certain number of steps (note that write returns a new TensorArray, which you need to keep):

this_array = this_array.write(this_array.size(), new_beta_weight)

You could also use a second array and update it simultaneously in order to record the epoch or batch at which your new_beta_weight was updated.
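For example, here is a minimal sketch of that idea: two dynamically sized TensorArrays, one for the beta values and one for the step at which each was set. The names beta_history, step_history, and log_beta are illustrative, not from the code above.

import tensorflow as tf

# Hypothetical logging arrays: one for beta values, one for the step index.
beta_history = tf.TensorArray(tf.float32, size=0, dynamic_size=True)
step_history = tf.TensorArray(tf.int64, size=0, dynamic_size=True)

def log_beta(beta_history, step_history, new_beta_weight, current_step):
    """Append the current beta and the step at which it was set.
    `write` returns a new TensorArray, so both results are reassigned."""
    beta_history = beta_history.write(beta_history.size(), new_beta_weight)
    step_history = step_history.write(step_history.size(), current_step)
    return beta_history, step_history

# Example usage after some training step:
beta_history, step_history = log_beta(
    beta_history, step_history,
    new_beta_weight=tf.constant(0.25),
    current_step=tf.constant(10, dtype=tf.int64))

print(beta_history.stack().numpy())  # [0.25]
print(step_history.stack().numpy())  # [10]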

Finally, the loss function itself looks like this:

def compute_loss(model, x, c, gamma_weight, capacity_weight):

  mean, logvar = model.encode(x, c)

  z = model.reparameterize(mean, logvar)
  reconstruction = model.decode(z, c)

  total_reconstruction_loss = tf.nn.sigmoid_cross_entropy_with_logits(
      labels=x, logits=reconstruction)
  total_reconstruction_loss = tf.reduce_sum(total_reconstruction_loss, 1)

  kl_loss = 1 + logvar - tf.square(mean) - tf.exp(logvar)
  kl_loss = tf.reduce_mean(kl_loss)
  kl_loss *= -0.5

  total_loss = tf.reduce_mean(total_reconstruction_loss * 3 + (
        gamma_weight * tf.abs(kl_loss - capacity_weight)))
  return total_loss

Note that model is of type tf.keras.Model. This should hopefully give you some different insights into this specific topic.
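To connect this back to the callback in the question: one option is to store kld_beta on the model as a non-trainable tf.Variable and update it from on_train_batch_end. Here is a minimal sketch of that approach (the class names below are illustrative, not from the code above):

import tensorflow as tf

class DummyVAE(tf.keras.Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Non-trainable so the optimizer never touches it, but the callback
        # (and any loss computed from it) can still read and write it.
        self.kld_beta = tf.Variable(0.0, trainable=False, dtype=tf.float32)

class BetaUpdater(tf.keras.callbacks.Callback):
    def on_train_batch_end(self, batch, logs=None):
        new_beta = 0.5  # placeholder; e.g. beta_tau_cyclic_annealing(compute_tau())
        self.model.kld_beta.assign(new_beta)
        # Equivalently: tf.keras.backend.set_value(self.model.kld_beta, new_beta)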
