簡體   English   中英

使用 pytorch 學習多元正態協方差矩陣

[英]Learning multivariate normal covariance matrix using pytorch

我正在嘗試使用一些觀察來學習多元正態協方差矩陣(Sigma,∑)。

我采用的方法是使用 pytorch.distributions.MultivariateNormal:

import torch
from torch.distributions import MultivariateNormal

# I tried both the scale_tril parameter and the covariance parameter.
mvn = MultivariateNormal(loc=torch.tensor([0.0, 0.0], requires_grad=False).view(1,2),
                        scale_tril=torch.tensor([[1.0 , 0.0], [0.0, 1.0]],
                                                requires_grad=True).view(-1, 2, 2))

loss = -mvn.log_prob(torch.ones((1, 2))).mean()
loss.backward()
print(mvn.loc.grad)

我沒有。 我嘗試擺弄 loc 和 scale_tril 參數的尺寸。 似乎沒有任何效果。 有任何想法嗎?

  • 我顯然可以自己實現這一點,但我非常喜歡使用現有工具。

最好的,埃亞爾。

你沒有在你的葉子節點上調用 .grad (在.view而不是張量本身),你也有requires_grad=False的意思,讓事情更明確

import torch
from torch.distributions import MultivariateNormal

mean = torch.tensor([0.0, 0.0], requires_grad=True)
cov = torch.tensor([[1.0 , 0.0], [0.0, 1.0]], requires_grad=True)

mvn = MultivariateNormal(loc=mean.view(1,2),
                         scale_tril=cov.view(-1, 2, 2))

loss = -mvn.log_prob(torch.ones((1, 2))).mean()
loss.backward()

print(mean.grad)
print(cov.grad)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM