[英]Which method should I use to sample from a normal distribution?
我正在嘗試從 N-dim 標准高斯分布中采樣 batch_size 點。 但我注意到我可以使用兩個類似的函數,我想知道哪個是正確的或兩者都是。
假設我想從 2-dim 標准高斯分布中采樣 8 個點。
torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2)).sample([8])
,它將返回一個大小為 [8,2] 的張量torch.randn(8,2)
它們返回相似的輸出,但我想知道它們是否相同。
torch.randn
為您提供來自單變量標准正態分布的樣本,並將它們重塑為所需的形狀。 所以所有樣本的均值為0
,具有單位方差。
x = torch.randn(1000000,2).numpy()
assert np.isclose(np.mean(x.flatten()), 0, atol=0.01)
plt.hist(x.flatten())
MultivariateNormal
從多元正態分布生成樣本。 它可以通過均值向量和協方差矩陣進行參數化。
x = torch.distributions.MultivariateNormal(
torch.zeros(2), torch.eye(2)).sample([10000]).numpy()
assert np.isclose(np.mean(x.flatten()), 0, atol=0.01)
plt.hist(x.flatten())
上面的用法有點像 hack; 我們在兩個不同的維度上創建兩個標准法線,並且由於分布(均值和方差)相同,我們可以認為它們也來自單個分布,因此扁平化數組可以被認為來自單變量標准法線。
MultivariateNormal
的真正目的:
x = torch.distributions.MultivariateNormal(
torch.tensor([-2.0, 2.0]), torch.eye(2)).sample([10000]).numpy()
plt.hist(x[:, 0])
plt.hist(x[:, 1])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.