Keras variational autoencoder: scale of the loss function

I am very new to neural networks and TensorFlow. Recently I have been reading up on the Keras implementation of a variational autoencoder, and I found two versions of the loss function:

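Version 1:

    from keras import backend as K
    from keras import objectives  # renamed keras.losses in later Keras versions

    def vae_loss(x, x_decoded_mean):
        # z_mean, z_log_var and original_dim come from the encoder (closure)
        recon_loss = original_dim * objectives.mse(x, x_decoded_mean)  # sum over pixels
        kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return recon_loss + kl_loss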
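To make the scale of version 1 concrete, here is a framework-free sketch in plain Python that computes the same per-sample quantity for a single example: `original_dim` times the mean squared error (which equals the squared error summed over pixels) plus the KL divergence to a standard normal, summed over latent dimensions. The function name `vae_loss_v1` and the toy numbers are hypothetical, and lists stand in for tensors.

```python
import math

def vae_loss_v1(x, x_decoded_mean, z_mean, z_log_var):
    """Hypothetical per-sample loss in the style of version 1."""
    original_dim = len(x)
    # mean squared error over pixels, then scaled back up to a sum
    mse = sum((a - b) ** 2 for a, b in zip(x, x_decoded_mean)) / original_dim
    recon_loss = original_dim * mse
    # KL(N(mu, sigma^2) || N(0, 1)) summed over latent dimensions
    kl_loss = -0.5 * sum(1 + lv - mu ** 2 - math.exp(lv)
                         for mu, lv in zip(z_mean, z_log_var))
    return recon_loss + kl_loss

x = [0.0, 1.0, 1.0, 0.0]
x_hat = [0.5, 0.5, 1.0, 0.0]
print(vae_loss_v1(x, x_hat, z_mean=[1.0, 0.0], z_log_var=[0.0, 0.0]))  # prints 1.0
```

Here the reconstruction term contributes 0.5 (two pixels off by 0.5) and the KL term contributes 0.5 (one latent dimension with mean 1).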

Version 2:

    def vae_loss(x, x_decoded_mean):
        # mse already averages over pixels (axis=-1)
        recon_loss = objectives.mse(x, x_decoded_mean)
        # KL averaged (not summed) over latent dimensions
        kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
        return recon_loss + kl_loss
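Version 2 can be sketched the same way. Because `mse` already averages over pixels and `K.mean` averages over latent dimensions, its reconstruction term equals version 1's divided by `original_dim`, while its KL term equals version 1's divided by the number of latent dimensions. The two terms are rescaled by different factors, so the weight of the KL term relative to reconstruction grows by `original_dim / latent_dim` (for example 784 / 2 = 392 with MNIST and a hypothetical 2-D latent space). The helper names and toy numbers below are hypothetical.

```python
import math

def recon_and_kl_v1(x, x_hat, z_mean, z_log_var):
    # version 1: squared error summed over pixels, KL summed over latent dims
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat))
    kl = -0.5 * sum(1 + lv - mu ** 2 - math.exp(lv)
                    for mu, lv in zip(z_mean, z_log_var))
    return recon, kl

def recon_and_kl_v2(x, x_hat, z_mean, z_log_var):
    # version 2: each term averaged over its own feature axis
    recon_sum, kl_sum = recon_and_kl_v1(x, x_hat, z_mean, z_log_var)
    return recon_sum / len(x), kl_sum / len(z_mean)

x, x_hat = [0.0, 1.0, 1.0, 0.0], [0.5, 0.5, 1.0, 0.0]
z_mean, z_log_var = [1.0, 0.0], [0.0, 0.0]
r1, k1 = recon_and_kl_v1(x, x_hat, z_mean, z_log_var)
r2, k2 = recon_and_kl_v2(x, x_hat, z_mean, z_log_var)
# KL-to-reconstruction ratio grows by original_dim / latent_dim = 4 / 2
print((k2 / r2) / (k1 / r1))  # 2.0
```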

If my understanding is correct, version 1 sums the loss and version 2 takes the mean loss across all samples in the same batch. Does the scale of the loss affect the learning result? I tried both, and the choice strongly affects the scale of my latent variables. Why is this, and which form of the loss function is correct?
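On the batch question: in both snippets `axis=-1` reduces over the last axis of a `(batch, features)` tensor, that is, over pixels or latent dimensions, so each version returns one loss value per sample. Keras then averages those per-sample values over the batch in a separate step. The plain-Python sketch below (hypothetical helper names, nested lists standing in for tensors) shows the two reductions side by side.

```python
def reduce_last_axis(batch, op):
    # like K.sum(..., axis=-1) or K.mean(..., axis=-1): one value per sample
    return [op(row) for row in batch]

def mean(values):
    return sum(values) / len(values)

batch = [[1.0, 2.0, 3.0],   # sample 1
         [4.0, 5.0, 6.0]]   # sample 2

per_sample_sum = reduce_last_axis(batch, sum)    # [6.0, 15.0]
per_sample_mean = reduce_last_axis(batch, mean)  # [2.0, 5.0]

# the batch-level reduction happens afterwards, over the per-sample values
print(mean(per_sample_sum), mean(per_sample_mean))  # 10.5 3.5
```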

Update to my question: if I also multiply the KL loss by original_dim,

    def vae_loss(x, x_decoded_mean):
        # binary cross-entropy averaged over pixels, times original_dim = summed over pixels
        xent_loss = original_dim * objectives.binary_crossentropy(x, x_decoded_mean)
        # KL term, now also multiplied by original_dim
        kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1) * original_dim
        return xent_loss + kl_loss

the latent distribution looks like this:

[image: latent distribution]

and the decoded output looks like this:

[image: decoded output]

It looks like the encoder output does not contain any information. I am using the MNIST dataset and the example from https://github.com/vvkv/Variational-Auto-Encoders/blob/master/Variational%2BAuto%2BEncoders.ipynb
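A plausible explanation for the uninformative latent space (often called posterior collapse): the KL term for each latent dimension is zero exactly when `z_mean = 0` and `z_log_var = 0`, that is, when the encoder outputs the prior N(0, 1) regardless of the input, and it is positive otherwise. Multiplying it by `original_dim` (784 for MNIST) makes that input-independent solution much cheaper relative to reconstruction error. The sketch below only evaluates the per-dimension KL formula from the snippets in plain Python with a hypothetical `kl_weight` parameter; it does not train a model.

```python
import math

def kl_per_dim(mu, log_var):
    # algebraically equal to -0.5 * (1 + log_var - mu^2 - exp(log_var));
    # zero only at mu = 0, log_var = 0 because exp(t) >= 1 + t
    return 0.5 * (mu ** 2 + math.exp(log_var) - 1 - log_var)

def weighted_loss(recon, kl, kl_weight):
    # kl_weight = 1 matches version 1; kl_weight = original_dim matches the update
    return recon + kl_weight * kl

print(kl_per_dim(0.0, 0.0))  # 0.0: encoder output equals the prior
print(kl_per_dim(1.0, 0.0))  # 0.5
print(kl_per_dim(0.0, 1.0))  # 0.5 * (e - 2), about 0.359
# a single latent mean of 1 costs 392 units of loss with kl_weight = 784
print(weighted_loss(0.0, kl_per_dim(1.0, 0.0), kl_weight=784))  # 392.0
```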

Summing versus averaging the loss over the examples in a batch simply scales all loss terms by the same factor (the batch size). An equivalent change would be adjusting the learning rate. This equivalence holds only when every term is scaled by the same factor. What matters is that the typical loss magnitude multiplied by the learning rate does not lead to unstable learning.
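The equivalence between summing over the batch and rescaling the learning rate can be checked with plain gradient descent on a toy one-parameter least-squares problem (the data and names are made up for illustration). The gradient of the summed loss is `batch_size` times the gradient of the mean loss, so summing with learning rate `lr / batch_size` takes the same steps as averaging with `lr`.

```python
def grad(w, batch, reduce):
    # d/dw of (w*x - y)^2 is 2*x*(w*x - y); reduce is "sum" or "mean"
    g = sum(2 * x * (w * x - y) for x, y in batch)
    return g if reduce == "sum" else g / len(batch)

def train(reduce, lr, steps=50):
    batch = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9)]  # toy data, roughly y = 2x
    w = 0.0
    for _ in range(steps):
        w -= lr * grad(w, batch, reduce)
    return w

lr = 0.05
w_mean = train("mean", lr)
w_sum = train("sum", lr / 3)  # batch size 3
print(w_mean, w_sum)  # the two runs agree up to floating-point rounding
```

With adaptive optimizers such as Adam, a constant rescaling of the whole loss is largely absorbed by the per-parameter normalization, so its effect is smaller still.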
