
How to calculate KL divergence between two batches of distributions in PyTorch?

Given a batch of distributions, represented as a PyTorch tensor:

A = torch.tensor([[0., 0., 0., 0., 1., 5., 1., 2.],
        [0., 0., 1., 0., 4., 2., 1., 1.],
        [0., 0., 1., 1., 0., 5., 1., 1.],
        [0., 1., 1., 0., 2., 3., 1., 1.],
        [0., 0., 2., 1., 3., 1., 1., 0.],
        [0., 0., 2., 0., 5., 0., 1., 0.],
        [0., 2., 1., 4., 0., 0., 1., 0.],
        [0., 0., 2., 4., 1., 0., 1., 0.]], device='cuda:0')

A is a batch of eight distributions. Now, given another batch of distributions, B:

B = torch.tensor([[0., 0., 1., 4., 2., 1., 1., 0.],
        [0., 0., 0., 5., 1., 2., 1., 0.],
        [0., 0., 0., 4., 2., 3., 0., 0.],
        [0., 0., 1., 7., 0., 0., 1., 0.],
        [0., 0., 1., 2., 4., 0., 1., 1.],
        [0., 0., 1., 3., 1., 3., 0., 0.],
        [0., 0., 1., 4., 1., 0., 2., 0.],
        [1., 0., 1., 5., 0., 1., 0., 0.],
        [0., 1., 5., 1., 0., 0., 1., 0.],
        [0., 0., 3., 2., 2., 0., 1., 0.],
        [0., 2., 4., 0., 1., 0., 1., 0.],
        [1., 0., 4., 1., 1., 1., 0., 0.]], device='cuda:0')

B has 12 distributions. I want to calculate the KL divergence between each distribution in A and each distribution in B, and obtain a KL distance matrix of shape 8 × 12 (one row per distribution in A, one column per distribution in B). I know how to do this with a loop and torch.nn.functional.kl_div(). Is there a way to do it in PyTorch without a for-loop?

Here is my implementation using for-loop:

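import torch
import torch.nn.functional as F

p_1 = F.softmax(A, dim=-1)  # each row of A -> a probability distribution
p_2 = F.softmax(B, dim=-1)  # each row of B -> a probability distribution
C = torch.empty(size=(A.shape[0], B.shape[0]), dtype=torch.float, device=A.device)

for i, a in enumerate(p_1):
    for j, b in enumerate(p_2):
        # kl_div expects the input as log-probabilities and the target as probabilities
        C[i, j] = F.kl_div(a.log(), b)
print(C)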

Output is:

tensor([[0.4704, 0.5431, 0.3422, 0.6284, 0.3985, 0.2003, 0.4925, 0.5739, 0.5793,
         0.3992, 0.5007, 0.4934],
        [0.3416, 0.4518, 0.2950, 0.5263, 0.0218, 0.2254, 0.3786, 0.4747, 0.3626,
         0.1823, 0.2960, 0.2937],
        [0.3845, 0.4306, 0.2722, 0.5022, 0.4769, 0.1500, 0.3964, 0.4556, 0.4609,
         0.3396, 0.4076, 0.3933],
        [0.2862, 0.3752, 0.2116, 0.4520, 0.1307, 0.1116, 0.3102, 0.3990, 0.2869,
         0.1464, 0.2164, 0.2225],
        [0.1829, 0.2674, 0.1763, 0.3227, 0.0244, 0.1481, 0.2067, 0.2809, 0.1675,
         0.0482, 0.1271, 0.1210],
        [0.4359, 0.5615, 0.4427, 0.6268, 0.0325, 0.4160, 0.4749, 0.5774, 0.3492,
         0.2093, 0.3015, 0.3014],
        [0.0235, 0.0184, 0.0772, 0.0286, 0.3462, 0.1461, 0.0142, 0.0162, 0.3524,
         0.1824, 0.2844, 0.2988],
        [0.0097, 0.0171, 0.0680, 0.0284, 0.2517, 0.1374, 0.0082, 0.0148, 0.2403,
         0.1058, 0.2100, 0.1978]], device='cuda:0')
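For reference, with the default reduction='mean', each entry the loop produces can be written as below, where p_i is row i of p_1, q_j is row j of p_2, and K = 8 is the number of categories. Because F.kl_div(input, target) measures the divergence of the input from the target, the entry is D_KL(q_j || p_i) scaled by 1/K; passing reduction='sum' instead would give the unscaled divergence.

```latex
C_{ij} = \frac{1}{K} \sum_{k=1}^{K} q_{jk}\left(\log q_{jk} - \log p_{ik}\right)
       = \frac{1}{K}\, D_{\mathrm{KL}}\left(q_j \,\|\, p_i\right)
```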

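Looking at the nn.KLDivLoss documentation, the formula for computing the KL divergence (with the default reduction='mean', where a is the input, already in log-space, and b is the target) is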

kl = torch.mean(b * (torch.log(b) - a))
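A quick way to convince yourself that this formula matches F.kl_div is to compare the two on a single pair. This is a minimal sketch using two random probability vectors (arbitrary values, not the question's data); it also shows that reduction="sum" yields the textbook divergence, which is exactly K = 8 times the "mean" result.

```python
import torch
import torch.nn.functional as F

# Two arbitrary probability vectors of length 8 (random, for illustration only)
torch.manual_seed(0)
p = F.softmax(torch.randn(8), dim=-1)
q = F.softmax(torch.randn(8), dim=-1)

a = p.log()  # kl_div's first argument must already be in log-space
b = q        # the target is given as plain probabilities

# Default reduction='mean' averages the 8 pointwise terms
builtin_mean = F.kl_div(a, b, reduction="mean")
manual_mean = torch.mean(b * (torch.log(b) - a))

# reduction='sum' gives the unscaled divergence D_KL(q || p)
builtin_sum = F.kl_div(a, b, reduction="sum")
manual_sum = torch.sum(b * (torch.log(b) - a))

print(builtin_mean.item(), manual_mean.item())
print(builtin_sum.item(), manual_sum.item())
```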

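We can use broadcasting to compute all pairs at once. The snippet treats A as the (log-space) input and B as the target, just as kl_div does; to reproduce the loop above exactly, substitute p_1.log() for A and p_2 for B: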

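import torch

# avoid NaNs from log(0): 0 * log(0) would be 0 * -inf = NaN,
# so where B == 0 we take log(1) = 0 and the term contributes 0
lB = B.clone()
lB[B == 0] = 1.

# do the computation efficiently:
# B[None, ...] is (1, 12, 8) and A[:, None, :] is (8, 1, 8), which broadcast to (8, 12, 8);
# averaging over the last dim gives the (8, 12) matrix C
C = (B[None, ...] * (torch.log(lB[None, ...]) - A[:, None, :])).mean(dim=-1)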
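To reproduce the question's loop exactly, the same broadcasting trick can be applied to the softmaxed rows. The sketch below is an illustration under that assumption: the names pairwise_kl_div, pairwise_kl_div_matmul and pairwise_kl_div_loop are hypothetical, and random count-like tensors with the question's shapes stand in for A and B. The second variant expands the sum into a per-target term minus a matrix product, so it never builds the (N, M, K) intermediate.

```python
import torch
import torch.nn.functional as F


def pairwise_kl_div(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: C[i, j] == F.kl_div(softmax(A)[i].log(), softmax(B)[j]).

    Uses the default 'mean' reduction (pointwise terms averaged over the last dim).
    Returns a tensor of shape (A.shape[0], B.shape[0]).
    """
    log_p1 = F.log_softmax(A, dim=-1)  # (N, K) log-probabilities ("input")
    p2 = F.softmax(B, dim=-1)          # (M, K) probabilities ("target")
    log_p2 = F.log_softmax(B, dim=-1)  # softmax is strictly positive here: no log(0)
    # (1, M, K) - (N, 1, K) broadcasts to (N, M, K); average over K -> (N, M)
    return (p2[None, :, :] * (log_p2[None, :, :] - log_p1[:, None, :])).mean(dim=-1)


def pairwise_kl_div_matmul(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Same result without the (N, M, K) intermediate.

    sum_k q_jk (log q_jk - log p_ik) = sum_k q_jk log q_jk - (log_p1 @ p2.T)[i, j]
    """
    log_p1 = F.log_softmax(A, dim=-1)                          # (N, K)
    p2 = F.softmax(B, dim=-1)                                  # (M, K)
    neg_entropy = (p2 * F.log_softmax(B, dim=-1)).sum(dim=-1)  # (M,)
    cross = log_p1 @ p2.T                                      # (N, M)
    return (neg_entropy[None, :] - cross) / A.shape[-1]        # divide by K


def pairwise_kl_div_loop(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Reference: the double loop from the question."""
    p_1 = F.softmax(A, dim=-1)
    p_2 = F.softmax(B, dim=-1)
    C = torch.empty(A.shape[0], B.shape[0], dtype=A.dtype, device=A.device)
    for i, a in enumerate(p_1):
        for j, b in enumerate(p_2):
            C[i, j] = F.kl_div(a.log(), b, reduction="mean")
    return C


# Random count-like stand-ins with the question's shapes (8 x 8 and 12 x 8)
torch.manual_seed(0)
A = torch.randint(0, 6, (8, 8)).float()
B = torch.randint(0, 6, (12, 8)).float()

C_loop = pairwise_kl_div_loop(A, B)
C_bcast = pairwise_kl_div(A, B)
C_mm = pairwise_kl_div_matmul(A, B)
print(C_bcast.shape)  # torch.Size([8, 12])
```

The broadcast version allocates an N × M × K tensor, which is harmless at these sizes; the matmul version replaces it with a single (N, K) @ (K, M) product, which should scale better when both batches are large.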

Come to think of it, I'm not sure what you are asking makes much sense. Your A and B tensors are filled with numbers, but they do not represent distributions (their rows do not sum to 1). Please consider carefully what you are trying to do here.
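One way to act on that remark is to check whether the rows actually are distributions before computing any divergence. The sketch below uses a hypothetical helper, is_distribution_batch, on the first two rows of the question's A (moved to the CPU). It contrasts softmax, which the question uses, with dividing each row by its sum; the latter assumes the numbers are counts, which the question does not state.

```python
import torch
import torch.nn.functional as F


def is_distribution_batch(t: torch.Tensor, atol: float = 1e-6) -> bool:
    """Hypothetical helper: True if every row is non-negative and sums to 1."""
    row_sums = t.sum(dim=-1)
    return bool((t >= 0).all()) and torch.allclose(
        row_sums, torch.ones_like(row_sums), atol=atol
    )


# First two rows of the question's A, on the CPU
A = torch.tensor([[0., 0., 0., 0., 1., 5., 1., 2.],
                  [0., 0., 1., 0., 4., 2., 1., 1.]])

print(A.sum(dim=-1))  # tensor([9., 9.]): the raw rows are not distributions

# Option used in the question: softmax gives every entry, even a 0, some mass
softmaxed = F.softmax(A, dim=-1)

# Alternative (assumes the rows are counts): normalize by the row sum;
# zero counts stay exactly zero
normalized = A / A.sum(dim=-1, keepdim=True)

print(is_distribution_batch(A))           # False
print(is_distribution_batch(softmaxed))   # True
print(is_distribution_batch(normalized))  # True
```

The choice matters: with row-normalized counts, D_KL(q || p) is infinite whenever q puts mass where p has a zero (for example, B's first row has a 4 at index 3 where A's first row has a 0), whereas softmax of inputs in this range produces no exact zeros.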

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM