簡體   English   中英

我應該使用哪種方法從正態分布中抽樣?

[英]Which method should I use to sample from a normal distribution?

我正在嘗試從 N-dim 標准高斯分布中采樣 batch_size 點。 但我注意到我可以使用兩個類似的函數,我想知道哪個是正確的或兩者都是。

假設我想從 2-dim 標准高斯分布中采樣 8 個點。

  1. torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2)).sample([8]) ,它將返回一個大小為 [8,2] 的張量
  2. torch.randn(8,2)

它們返回相似的輸出,但我想知道它們是否相同。

torch.randn為您提供來自單變量標准正態分布的樣本,並將它們重塑為所需的形狀。 所以所有樣本的均值為0 ,具有單位方差。

x = torch.randn(1000000,2).numpy()
assert np.isclose(np.mean(x.flatten()), 0, atol=0.01)
plt.hist(x.flatten())

在此處輸入圖像描述

MultivariateNormal從多元正態分布生成樣本。 它可以通過均值向量和協方差矩陣進行參數化。

x = torch.distributions.MultivariateNormal(
    torch.zeros(2), torch.eye(2)).sample([10000]).numpy()
assert np.isclose(np.mean(x.flatten()), 0, atol=0.01)
plt.hist(x.flatten())

上面的用法有點像 hack; 我們在兩個不同的維度上創建兩個標准法線,並且由於分布(均值和方差)相同,我們可以認為它們也來自單個分布,因此扁平化數組可以被認為來自單變量標准法線。 在此處輸入圖像描述

MultivariateNormal的真正目的:

x = torch.distributions.MultivariateNormal(
    torch.tensor([-2.0, 2.0]), torch.eye(2)).sample([10000]).numpy()
plt.hist(x[:, 0])
plt.hist(x[:, 1])

在此處輸入圖像描述

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM