Different results in computing KL Divergence using Pytorch Distributions vs manually

I noticed that the KL-divergence term KL(Q(x)||P(x)) comes out differently when computed as

mean(Q(x)*(log Q(x) - log P(x)))

vs

torch.distributions.kl_divergence(Q, P)

where

Q = torch.distributions.Normal(some mean, some sigma)
P = torch.distributions.Normal(0, 1)

When I plot the KL-divergence losses, I get two similar but different plots.

Can anyone point out what is causing this difference?

The full code is below:

import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

def kl_1(log_qx, log_px):
    """
    inputs: [B, z_dim] torch
    """
    return (log_qx.exp() * (log_qx-log_px)).mean()

# ground-truth (target) P(x)
P = dist.Normal(0, 1)


mus = np.arange(-5, 5, 0.1)
sigma = 1
N = 100
kls = {"1": [], "2": []}
for mu in mus:
    # prediction (current) Q(x)
    Q = dist.Normal(mu, sigma)
    
    # sample from Q
    qx = Q.sample((N,))
        
    # log prob
    log_qx = Q.log_prob(qx)
    log_px = P.log_prob(qx)
    
    # kl 1
    kl1 = kl_1(log_qx, log_px)
    kls['1'].append(kl1.numpy())
    
    # kl 2
    kl2 = dist.kl_divergence(Q, P)
    kls['2'].append(kl2.numpy())
    
plt.figure()
plt.scatter(mus, kls['1'], label="Q*(logQ-logP)")
plt.scatter(mus, kls['2'], label="kl_divergence")
plt.xlabel("mean of Q(x)")
plt.ylabel("computed KL Divergence")
plt.legend()
plt.show()

The density weight q(x) belongs in the formula only when you compute the expected value as an integral over dx. If you instead use samples drawn from the distribution itself, you approximate the expected value directly by the sample mean. That corresponds to integrating over d cq(x), where d cq(x) = q(x) dx, cq(x) is the cumulative distribution function, and q(x) is the probability density function of Q. The samples are already distributed according to q(x), so multiplying by q(x) again counts the weight twice. The corrected script below drops the factor log_qx.exp() from kl_1.
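Written out, the argument looks as follows. This is a sketch for the setup in the question (Q = Normal(mu, 1), P = Normal(0, 1)); the symbols x_i and N denote the samples and the sample count, and the closed-form values quoted afterwards are my own derivation for that specific case, not something stated in the original posts.

```latex
\begin{align}
\mathrm{KL}(Q \,\|\, P)
  &= \int q(x)\,\bigl(\log q(x) - \log p(x)\bigr)\,dx
   = \mathbb{E}_{x \sim Q}\bigl[\log q(x) - \log p(x)\bigr] \\
  &\approx \frac{1}{N}\sum_{i=1}^{N}\bigl(\log q(x_i) - \log p(x_i)\bigr),
   \qquad x_i \sim Q, \\[4pt]
\frac{1}{N}\sum_{i=1}^{N} q(x_i)\,\bigl(\log q(x_i) - \log p(x_i)\bigr)
  &\approx \int q(x)^{2}\,\bigl(\log q(x) - \log p(x)\bigr)\,dx .
\end{align}
```

For Q = Normal(mu, 1) and P = Normal(0, 1), the first expression equals mu^2 / 2, while the density-weighted one works out to mu^2 / (4 sqrt(pi)), roughly 0.141 mu^2. Both are parabolas in mu, which would be consistent with the two curves looking similar but not matching.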

import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

def kl_1(log_qx, log_px):
    """
    inputs: [B, z_dim] torch
    """
    return (log_qx-log_px).mean()

# ground-truth (target) P(x)
P = dist.Normal(0, 1)


mus = np.arange(-5, 5, 0.1)
sigma = 1
N = 100
kls = {"1": [], "2": []}
for mu in mus:
    # prediction (current) Q(x)
    Q = dist.Normal(mu, sigma)
    
    # sample from Q
    qx = Q.sample((N,))
        
    # log prob
    log_qx = Q.log_prob(qx)
    log_px = P.log_prob(qx)
    
    # kl 1
    kl1 = kl_1(log_qx, log_px)
    kls['1'].append(kl1.numpy())
    
    # kl 2
    kl2 = dist.kl_divergence(Q, P)
    kls['2'].append(kl2.numpy())
    
plt.figure()
plt.scatter(mus, kls['1'], label="mean(logQ-logP)")
plt.scatter(mus, kls['2'], label="kl_divergence")
plt.xlabel("mean of Q(x)")
plt.ylabel("computed KL Divergence")
plt.legend()
plt.show()
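As a quick sanity check that does not depend on torch, the comparison can be sketched with the Python standard library alone. The helper names below (log_normal_pdf, kl_closed_form, kl_sample_mean, kl_density_weighted) are hypothetical and not part of any library. The closed form used is the standard KL divergence between a univariate Gaussian and the standard normal, which is presumably what dist.kl_divergence evaluates for two Normal distributions.

```python
import math
import random


def log_normal_pdf(x, mu, sigma):
    """Log density of Normal(mu, sigma) at x."""
    return (-0.5 * ((x - mu) / sigma) ** 2
            - math.log(sigma)
            - 0.5 * math.log(2.0 * math.pi))


def kl_closed_form(mu, sigma):
    """Analytic KL(Normal(mu, sigma) || Normal(0, 1))."""
    return 0.5 * (sigma ** 2 + mu ** 2 - 1.0) - math.log(sigma)


def kl_sample_mean(mu, sigma, n, rng):
    """Monte Carlo estimate: plain mean of log q - log p over samples from Q."""
    total = 0.0
    for _ in range(n):
        x = rng.gauss(mu, sigma)
        total += log_normal_pdf(x, mu, sigma) - log_normal_pdf(x, 0.0, 1.0)
    return total / n


def kl_density_weighted(mu, sigma, n, rng):
    """The estimator from the question: each term is also multiplied by q(x)."""
    total = 0.0
    for _ in range(n):
        x = rng.gauss(mu, sigma)
        log_q = log_normal_pdf(x, mu, sigma)
        log_p = log_normal_pdf(x, 0.0, 1.0)
        total += math.exp(log_q) * (log_q - log_p)
    return total / n


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so runs are repeatable
    for mu in (0.5, 1.0, 2.0, 3.0):
        exact = kl_closed_form(mu, 1.0)
        plain = kl_sample_mean(mu, 1.0, 100_000, rng)
        weighted = kl_density_weighted(mu, 1.0, 100_000, rng)
        print(f"mu={mu}: closed form={exact:.3f}, "
              f"sample mean={plain:.3f}, density-weighted={weighted:.3f}")
```

With enough samples, the plain sample mean should track the closed form, while the density-weighted version settles on a smaller value. With only N = 100 samples, as in the scripts above, the sample-mean points will still scatter noticeably around the analytic curve.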
