

Discrepancy between log_prob and manual calculation

I want to define a multivariate normal distribution with mean [1, 1, 1] and a variance-covariance matrix with 0.3 on the diagonal. After that, I want to calculate the log-likelihood of the data point [2, 3, 4].

Using torch.distributions:

import torch
import torch.distributions as td

input_x = torch.tensor([2.0, 3.0, 4.0])  # float, matching the dtype of loc
loc = torch.ones(3)
scale = torch.eye(3) * 0.3  # meant to be the covariance matrix
mvn = td.MultivariateNormal(loc=loc, scale_tril=scale)
mvn.log_prob(input_x)
# tensor(-76.9227)

From scratch

Using the formula for the log-likelihood:

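$$\log p(x) = -\frac{1}{2}\log\left((2\pi)^k \lvert\Sigma\rvert\right) - \frac{1}{2}(x - \mu)^\top \Sigma^{-1} (x - \mu)$$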

We obtain:

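import numpy as np

# -1/2 * log((2*pi)^k * det(Sigma)) with k = 3 and det(Sigma) = 0.3**3
first_term = (2 * np.pi * 0.3) ** 3
first_term = -np.log(np.sqrt(first_term))
x_center = input_x - loc
# -1/2 * (x - mu)^T Sigma^{-1} (x - mu), with `scale` as Sigma
tmp = torch.matmul(x_center, scale.inverse())
tmp = -1 / 2 * torch.matmul(tmp, x_center)
first_term + tmp
# tensor(-24.2842)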

where I used the fact that $\lvert\Sigma\rvert = \det(0.3\,I_3) = 0.3^3$, so that $(2\pi)^3\lvert\Sigma\rvert = (2\pi \cdot 0.3)^3$.
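Worked through by hand with $\Sigma = 0.3\,I_3$ and $x - \mu = (1, 2, 3)^\top$, the same formula gives the manual value:

$$
\begin{aligned}
\log p(x) &= -\tfrac{1}{2}\log\left((2\pi \cdot 0.3)^3\right) - \tfrac{1}{2}\cdot\frac{1^2 + 2^2 + 3^2}{0.3} \\
&= -\tfrac{3}{2}\log(0.6\pi) - \frac{14}{0.6} \\
&\approx -0.9509 - 23.3333 = -24.2842
\end{aligned}
$$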

My question is: what is the source of this discrepancy?

You are passing the covariance matrix to scale_tril instead of covariance_matrix. From the docs of PyTorch's MultivariateNormal:

scale_tril (Tensor) – lower-triangular factor of covariance, with positive-valued diagonal
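In other words, scale_tril is a lower-triangular $L$ with $\Sigma = L L^\top$, so passing $0.3\,I$ as scale_tril makes the distribution use $\Sigma = 0.09\,I$, not $0.3\,I$. A short sketch (reusing the question's variable names; `sigma`, `k` and `manual` are illustrative) that reads back the covariance PyTorch actually uses and plugs it into the hand formula:

```python
import math

import torch
import torch.distributions as td

input_x = torch.tensor([2.0, 3.0, 4.0])
loc = torch.ones(3)
scale = torch.eye(3) * 0.3

# scale_tril is interpreted as L, so the covariance becomes L @ L.T
mvn = td.MultivariateNormal(loc=loc, scale_tril=scale)
print(mvn.covariance_matrix)  # diagonal entries 0.09, not 0.3

# Redo the hand formula with the covariance PyTorch actually used
sigma = mvn.covariance_matrix
x_center = input_x - loc
k = input_x.shape[0]
manual = -0.5 * (k * math.log(2 * math.pi) + torch.logdet(sigma)
                 + x_center @ torch.linalg.solve(sigma, x_center))
print(manual, mvn.log_prob(input_x))  # both approximately -76.9227
```

So both results are correct; they simply describe two different covariance matrices.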

So, replacing scale_tril with covariance_matrix yields the same result as your manual calculation:

In [1]: mvn = td.MultivariateNormal(loc = loc, covariance_matrix=scale)
In [2]: mvn.log_prob(input_x)
Out[2]: tensor(-24.2842)
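The hand formula is not tied to a diagonal covariance. A minimal sketch, using a hypothetical helper `manual_log_prob` written only for this illustration, that handles any symmetric positive-definite $\Sigma$ and agrees with `covariance_matrix=` (the off-diagonal matrix below is an arbitrary example):

```python
import math

import torch
import torch.distributions as td


def manual_log_prob(x, mu, sigma):
    """Hypothetical helper: multivariate normal log-density from the textbook formula."""
    k = x.shape[-1]
    diff = x - mu
    # Solve sigma @ y = diff instead of forming sigma.inverse() explicitly
    maha = diff @ torch.linalg.solve(sigma, diff)
    return -0.5 * (k * math.log(2 * math.pi) + torch.logdet(sigma) + maha)


loc = torch.ones(3)
input_x = torch.tensor([2.0, 3.0, 4.0])
# An arbitrary symmetric positive-definite covariance with off-diagonal terms
sigma = torch.tensor([[0.3, 0.1, 0.0],
                      [0.1, 0.3, 0.1],
                      [0.0, 0.1, 0.3]])

reference = td.MultivariateNormal(loc=loc, covariance_matrix=sigma).log_prob(input_x)
print(manual_log_prob(input_x, loc, sigma), reference)
```

Using `torch.linalg.solve` and `torch.logdet` avoids the explicit inverse and the hardcoded `0.3**3` of the original hand calculation.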

However, according to the docs, using scale_tril is more efficient:

...Using scale_tril will be more efficient:
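As far as I can tell from the PyTorch documentation, when covariance_matrix is given, the distribution derives a lower-triangular scale_tril from it with a Cholesky decomposition and bases its computations on that factor. A small sketch to see this for the matrix in the question:

```python
import torch
import torch.distributions as td

loc = torch.ones(3)
cov = torch.eye(3) * 0.3

# Constructed from the covariance; the lower-triangular factor is derived from it
mvn = td.MultivariateNormal(loc=loc, covariance_matrix=cov)
print(mvn.scale_tril)  # diagonal entries sqrt(0.3), about 0.5477

# Multiplying the factor by its transpose recovers the covariance
print(mvn.scale_tril @ mvn.scale_tril.T)
```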

You can compute the lower Cholesky factor using torch.cholesky:

In [3]: mvn = td.MultivariateNormal(loc = loc, scale_tril=torch.cholesky(scale))
In [4]: mvn.log_prob(input_x)
Out[4]: tensor(-24.2842)
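On recent PyTorch versions torch.cholesky is deprecated in favor of torch.linalg.cholesky, which returns the same lower-triangular factor. For a diagonal covariance like this one, the factor is simply the element-wise square root of the diagonal, so either route below should give the manual result (a sketch; `mvn_chol` and `mvn_diag` are illustrative names):

```python
import torch
import torch.distributions as td

input_x = torch.tensor([2.0, 3.0, 4.0])
loc = torch.ones(3)
cov = torch.eye(3) * 0.3

# General route: lower Cholesky factor of the covariance
mvn_chol = td.MultivariateNormal(loc=loc, scale_tril=torch.linalg.cholesky(cov))

# Diagonal shortcut: the Cholesky factor of diag(d) is diag(sqrt(d))
mvn_diag = td.MultivariateNormal(loc=loc, scale_tril=torch.diag(cov.diagonal().sqrt()))

print(mvn_chol.log_prob(input_x), mvn_diag.log_prob(input_x))  # both about -24.2842
```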

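Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.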

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM