簡體   English   中英

keras 變分自編碼器損失函數

[英]keras variational autoencoder loss function

我已經閱讀了 Keras 關於 VAE 實現的這篇博客,其中 VAE 損失是這樣定義的:

def vae_loss(x, x_decoded_mean):
    xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
    kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
    return xent_loss + kl_loss

我查看了Keras 文檔,VAE 損失函數是這樣定義的:在這個實現中,reconstruction_loss 乘以 original_dim,我在第一個實現中沒有看到!

if args.mse:
        reconstruction_loss = mse(inputs, outputs)
    else:
        reconstruction_loss = binary_crossentropy(inputs,
                                                  outputs)

    reconstruction_loss *= original_dim
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    vae_loss = K.mean(reconstruction_loss + kl_loss)
    vae.add_loss(vae_loss)

有人可以解釋為什么嗎? 謝謝!

first_one: CE + mean(kl, axis=-1) = CE + sum(kl, axis=-1) / d

second_one: d * CE + sum(kl, axis=-1)

所以: first_one = second_one / d

請注意,第二個返回所有樣本的平均損失,但第一個返回所有樣本的損失向量。

在 VAE 中,重建損失函數可以表示為:

reconstruction_loss = - log(p ( x | z))

如果解碼器輸出分布假設為高斯分布,則損失函數歸結為 MSE,因為:

reconstruction_loss = - log(p( x | z)) = - log ∏ ( N(x(i), x_out(i), sigma**2) = − ∑ log ( N(x(i), x_out(i), sigma**2) . alpha . ∑ (x(i), x_out(i))**2

相比之下,MSE 損失的方程為:

L(x,x_out) = MSE = 1/m ∑ (x(i) - x_out(i)) **2

其中 m 是輸出維度。 例如,在 MNIST 中 m = 寬 × 高 × 通道 = 28 × 28 × 1 = 784

因此,

reconstruction_loss = mse(inputs, outputs)

應該乘以 m(即原始維度)以等於 VAE 公式中的原始重建損失。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM