[英]python multivariate normal pdf 3d plot
我試圖顯示mnist數據集中所有零數字的多元正態pdf的3d圖。
from scipy.stats import multivariate_normal
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
my0 = np.mean(num_arrays[0],axis=0)
sigma0 = np.identity(784)
p0 = multivariate_normal(my0,sigma0)
X, Y = np.mgrid[-10:10:.1, -10:10:.1]
pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X
pos[:, :, 1] = Y
fig = plt.figure()
ax = fig.gca(projection='3d')
ax.plot_surface(X, Y, p0_id.pdf(pos),cmap='viridis',linewidth=0)
我收到以下錯誤消息:
operands could not be broadcast together with shapes (200,200,2) (784,)
我在這里做錯了什么?
編輯:完整的錯誤消息
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-70-584e158fe420> in <module>()
13 fig = plt.figure()
14 ax = fig.gca(projection='3d')
---> 15 ax.plot_surface(X, Y, p0.pdf(pos),cmap='viridis',linewidth=0)
~\Anaconda3\lib\site-packages\scipy\stats\_multivariate.py in pdf(self, x)
608
609 def pdf(self, x):
--> 610 return np.exp(self.logpdf(x))
611
612 def rvs(self, size=1, random_state=None):
~\Anaconda3\lib\site-packages\scipy\stats\_multivariate.py in logpdf(self, x)
604 x = self._dist._process_quantiles(x, self.dim)
605 out = self._dist._logpdf(x, self.mean, self.cov_info.U,
--> 606 self.cov_info.log_pdet, self.cov_info.rank)
607 return _squeeze_output(out)
608
~\Anaconda3\lib\site-packages\scipy\stats\_multivariate.py in _logpdf(self, x, mean, prec_U, log_det_cov, rank)
452
453 """
--> 454 dev = x - mean
455 maha = np.sum(np.square(np.dot(dev, prec_U)), axis=-1)
456 return -0.5 * (rank * _LOG_2PI + log_det_cov + maha)
ValueError: operands could not be broadcast together with shapes (200,200,2) (784,)
我正在嘗試做一些相同的事情,並且唯一的想法是我發現對它的原始形狀有所了解,以計算該函數的點對點結果並使用Axes3D.scatter()
函數繪制該點,這是一個如何使用python獲得3D高斯形狀的示例
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import multivariate_normal
mean = np.array([1., 1.])
cov_matrix = np.array([[2., 0.], [0., 2.]])
fig = plt.figure()
ax = fig.gca(projection="3d")
x = np.linspace(-3., 3., 20)
y = np.linspace(-3., 3., 20)
for i in x:
for j in y:
ax.scatter(i, j, pdf_2d(i, j, multivariate_normal.pdf([i, j], mean=mean_, cov=cov_matrix_))
plt.show()
當meshgrid產生多維數組時,我認為pdf()
函數無法將矢量轉換應用於每個元素。
不幸的是,這是我發現對高斯形狀的唯一了解。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.