
Different results in computing KL divergence using PyTorch distributions vs manually

I noticed that the KL divergence term KL(Q(x)||P(x)) is computed differently when using

mean(Q(x)*(log Q(x) - log P(x)))

vs

torch.distributions.kl_divergence(Q, P)

where

Q = torch.distributions.Normal(some mean, some sigma)
P = torch.distributions.Normal(0, 1)

and when I plot the KL divergence losses, I get these two similar but different plots: here

Can anyone point out what is causing this difference?
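
For reference, my understanding is that dist.kl_divergence evaluates the closed-form expression for two Gaussians, with no sampling involved; a minimal sketch checking that (mu and sigma here are just example values):

import torch
import torch.distributions as dist

mu, sigma = 2.0, 1.0
Q = dist.Normal(mu, sigma)
P = dist.Normal(0.0, 1.0)

# Closed form for KL(N(mu, sigma) || N(0, 1)):
#   -log(sigma) + (sigma**2 + mu**2) / 2 - 1/2
# (with sigma = 1 this reduces to mu**2 / 2)
closed_form = -torch.log(torch.tensor(sigma)) + (sigma**2 + mu**2) / 2 - 0.5
print(closed_form)               # tensor(2.)
print(dist.kl_divergence(Q, P))  # tensor(2.)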

The full code is below:

import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

def kl_1(log_qx, log_px):
    """
    inputs: log-probabilities of samples drawn from Q, shape [N]
    """
    # weights each term by q(x) = exp(log q(x)) before averaging
    return (log_qx.exp() * (log_qx - log_px)).mean()

# ground-truth (target) P(x)
P = dist.Normal(0, 1)


mus = np.arange(-5, 5, 0.1)
sigma = 1
N = 100
kls = {"1": [], "2": []}
for mu in mus:
    # prediction (current) Q(x)
    Q = dist.Normal(mu, sigma)
    
    # sample from Q
    qx = Q.sample((N,))
        
    # log prob
    log_qx = Q.log_prob(qx)
    log_px = P.log_prob(qx)
    
    # kl 1
    kl1 = kl_1(log_qx, log_px)
    kls['1'].append(kl1.numpy())
    
    # kl 2
    kl2 = dist.kl_divergence(Q, P)
    kls['2'].append(kl2.numpy())
    
plt.figure()
plt.scatter(mus, kls['1'], label="Q*(logQ-logP)")
plt.scatter(mus, kls['2'], label="kl_divergence")
plt.xlabel("mean of Q(x)")
plt.ylabel("computed KL Divergence")
plt.legend()
plt.show()

You only weight the sample by the probability density q(x) when you compute the expected value as an integral over dx. If you are averaging over samples drawn from the distribution itself, you approximate the expected value as the plain mean directly; that corresponds to integrating over d c_Q(x), where d c_Q(x) = q(x) dx, c_Q(x) is the cumulative distribution function, and q(x) is the probability density function of Q. The sampling step already supplies the q(x) weighting, so multiplying by exp(log_qx) inside the mean applies the density a second time, which is what skews your first curve.
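
Spelled out, this is the standard Monte Carlo identity:

KL(Q||P) = ∫ q(x) * (log q(x) - log p(x)) dx
         = E_{x~Q}[ log q(x) - log p(x) ]
         ≈ (1/N) * Σ_i ( log q(x_i) - log p(x_i) ),   with x_i sampled from Q

so the corrected kl_1 below simply averages log_qx - log_px: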

import numpy as np
import torch
import torch.distributions as dist
import matplotlib.pyplot as plt

def kl_1(log_qx, log_px):
    """
    inputs: log-probabilities of samples drawn from Q, shape [N]
    """
    # plain mean: the q(x) weighting is already implied by sampling from Q
    return (log_qx - log_px).mean()

# ground-truth (target) P(x)
P = dist.Normal(0, 1)


mus = np.arange(-5, 5, 0.1)
sigma = 1
N = 100
kls = {"1": [], "2": []}
for mu in mus:
    # prediction (current) Q(x)
    Q = dist.Normal(mu, sigma)
    
    # sample from Q
    qx = Q.sample((N,))
        
    # log prob
    log_qx = Q.log_prob(qx)
    log_px = P.log_prob(qx)
    
    # kl 1
    kl1 = kl_1(log_qx, log_px)
    kls['1'].append(kl1.numpy())
    
    # kl 2
    kl2 = dist.kl_divergence(Q, P)
    kls['2'].append(kl2.numpy())
    
plt.figure()
plt.scatter(mus, kls['1'], label="mean(logQ-logP)")
plt.scatter(mus, kls['2'], label="kl_divergence")
plt.xlabel("mean of Q(x)")
plt.ylabel("computed KL Divergence")
plt.legend()
plt.show()
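
As a quick sanity check (a sketch I added, not part of the original answer): with a large sample size the corrected Monte Carlo estimate converges to the analytic value:

import torch
import torch.distributions as dist

torch.manual_seed(0)

Q = dist.Normal(2.0, 1.0)
P = dist.Normal(0.0, 1.0)

# draw a large sample so the Monte Carlo noise is small
x = Q.sample((1_000_000,))
mc_estimate = (Q.log_prob(x) - P.log_prob(x)).mean()

print(mc_estimate)               # ~2.0, up to sampling noise
print(dist.kl_divergence(Q, P))  # tensor(2.) exactly (closed form)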
