
Keras variational autoencoder loss function

I've read this blog by Keras on VAE implementation, where VAE loss is defined this way:

def vae_loss(x, x_decoded_mean):
    xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
    kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
    return xent_loss + kl_loss

I then looked at the VAE example in the Keras documentation, where the loss is defined as shown below. In that implementation, reconstruction_loss is multiplied by original_dim, which does not happen in the first implementation:

if args.mse:
    reconstruction_loss = mse(inputs, outputs)
else:
    reconstruction_loss = binary_crossentropy(inputs, outputs)

reconstruction_loss *= original_dim
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)

Can somebody please explain why? Thank you!

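Write CE for the per-sample cross-entropy (Keras already averages it over the last axis), kl for the KL term with the -0.5 factor folded in, and d for the dimension:

first_one: CE + mean(kl, axis=-1) = CE + sum(kl, axis=-1) / d

second_one: d * CE + sum(kl, axis=-1)

So: first_one = second_one / d

Strictly, this holds only when d is the same in both places. The mean in the first version runs over the latent dimensions, while the second version multiplies by original_dim, so if the latent size differs from original_dim the two losses also weight the KL term differently.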

And note that the second one returns the mean loss over all the samples, but the first one returns a vector of losses for all samples.
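The relation between the two losses can be checked numerically for a single sample. This is a minimal sketch in plain Python rather than Keras: the helper names (`first_loss`, `second_loss`, `kl_terms`) and all input values are made up for illustration, and the first case assumes the input and latent sizes are both d = 4.

```python
import math

def binary_crossentropy(x, x_hat):
    """Per-sample cross-entropy averaged over the features (last axis)."""
    terms = [-(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
             for t, p in zip(x, x_hat)]
    return sum(terms) / len(terms)

def kl_terms(z_mean, z_log_var):
    """Per-dimension terms 1 + log_var - mean^2 - exp(log_var)."""
    return [1.0 + lv - m * m - math.exp(lv) for m, lv in zip(z_mean, z_log_var)]

def first_loss(x, x_hat, z_mean, z_log_var):
    # CE - 0.5 * mean(kl_terms): the blog-post version
    kl = kl_terms(z_mean, z_log_var)
    return binary_crossentropy(x, x_hat) - 0.5 * sum(kl) / len(kl)

def second_loss(x, x_hat, z_mean, z_log_var):
    # original_dim * CE - 0.5 * sum(kl_terms): the documentation version
    kl = kl_terms(z_mean, z_log_var)
    return len(x) * binary_crossentropy(x, x_hat) - 0.5 * sum(kl)

# Made-up values for one sample; input size and latent size are both 4 here.
x = [1.0, 0.0, 1.0, 0.0]
x_hat = [0.9, 0.2, 0.6, 0.1]
z_mean = [0.5, -1.0, 0.3, 0.0]
z_log_var = [0.1, -0.2, 0.0, 0.4]
d = len(x)

# Equal sizes: first_one and second_one / d agree up to rounding.
same_dims_gap = abs(first_loss(x, x_hat, z_mean, z_log_var)
                    - second_loss(x, x_hat, z_mean, z_log_var) / d)

# Latent size 2, input size 4: the simple ratio no longer holds.
diff_dims_gap = abs(first_loss(x, x_hat, z_mean[:2], z_log_var[:2])
                    - second_loss(x, x_hat, z_mean[:2], z_log_var[:2]) / d)

print(same_dims_gap, diff_dims_gap)
```

The first gap is zero up to floating-point rounding. The second is not, because the KL term is divided by the latent size in one loss and by the input size in the other.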

In VAE, the reconstruction loss function can be expressed as:

reconstruction_loss = - log(p ( x | z))

If the decoder output distribution is assumed to be Gaussian, then the loss function boils down to MSE since:

reconstruction_loss = - log(p(x | z)) = - log ∏ N(x(i); x_out(i), sigma**2) = - ∑ log N(x(i); x_out(i), sigma**2) ∝ ∑ (x(i) - x_out(i))**2

where ∝ holds up to an additive constant, assuming sigma is fixed.
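The step from the Gaussian likelihood to a sum of squared errors can be spelled out. This is a sketch that assumes independent output dimensions and a single fixed variance sigma**2 shared by all of them, which the answer implies but does not state:

```latex
\begin{aligned}
-\log p(x \mid z)
  &= -\sum_{i=1}^{m} \log \mathcal{N}\!\left(x^{(i)};\, x_{\text{out}}^{(i)},\, \sigma^2\right) \\
  &= \sum_{i=1}^{m} \left[ \frac{\left(x^{(i)} - x_{\text{out}}^{(i)}\right)^2}{2\sigma^2}
       + \frac{1}{2}\log\!\left(2\pi\sigma^2\right) \right] \\
  &= \frac{1}{2\sigma^2} \sum_{i=1}^{m} \left(x^{(i)} - x_{\text{out}}^{(i)}\right)^2
       + \frac{m}{2}\log\!\left(2\pi\sigma^2\right)
\end{aligned}
```

The second term does not depend on x_out, so minimizing the negative log-likelihood is equivalent to minimizing the sum of squared errors.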

In contrast, the equation for the MSE loss is:

L(x,x_out) = MSE = 1/m ∑ (x(i) - x_out(i)) **2

where m is the number of output dimensions. For example, in MNIST m = width × height × channels = 28 × 28 × 1 = 784.

Thus,

reconstruction_loss = mse(inputs, outputs)

should be multiplied by m (i.e. original_dim) to match the reconstruction loss in the VAE formulation, which sums over all output dimensions instead of averaging over them.
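A small numerical check of this scaling, as a sketch in plain Python: the values, the helper names `mse` and `gaussian_nll`, and the choice sigma = 1 are assumptions made for illustration only.

```python
import math

def mse(x, x_out):
    """Mean squared error, averaged over the m output dimensions."""
    return sum((a - b) ** 2 for a, b in zip(x, x_out)) / len(x)

def gaussian_nll(x, x_out, sigma):
    """-log prod_i N(x_i; x_out_i, sigma^2), summed over all dimensions."""
    return sum(
        0.5 * math.log(2.0 * math.pi * sigma ** 2)
        + (a - b) ** 2 / (2.0 * sigma ** 2)
        for a, b in zip(x, x_out)
    )

# Made-up input and reconstruction with m = 4 dimensions.
x = [0.0, 0.5, 1.0, 0.25]
x_out = [0.1, 0.4, 0.7, 0.25]
m = len(x)
sigma = 1.0  # assumed fixed decoder standard deviation

# Additive constant of the Gaussian density; independent of x_out.
const = 0.5 * m * math.log(2.0 * math.pi * sigma ** 2)

# m * MSE is the sum of squared errors; rescale by 1 / (2 sigma^2).
scaled_mse = m * mse(x, x_out) / (2.0 * sigma ** 2)

gap = abs(gaussian_nll(x, x_out, sigma) - (scaled_mse + const))
print(gap)
```

The gap is zero up to rounding: the negative log-likelihood equals m * MSE times a fixed factor plus a constant, whereas the unscaled MSE would understate the reconstruction term by a factor of m relative to the KL term.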
