简体   繁体   English

tensorflow_probability 分布应该如何用于多维空间?

[英]How should tensorflow_probability distributions be used for multi-dimensional spaces?

I would like to create a multi-dimensional gaussian probability density function (let's say a 2D gaussian like in the figure below) with tensorflow.我想用 tensorflow 创建一个多维高斯概率密度 function(比如说像下图中的二维高斯)。

在此处输入图像描述

For 1D, it works like a charm:对于 1D,它就像一个魅力:

d = tfp.distributions.Normal(loc=5.0, scale=3.0)
x = d.prob(tf.range(0,10, dtype=tf.float32))

But for higher dimension, I get InvalidArgumentError: Incompatible shapes error using Normal or MultivariateNormalDiag distributions... What do I miss?但对于更高的维度,我得到InvalidArgumentError: Incompatible shapes error using Normal or MultivariateNormalDiag分布......我错过了什么? How should the prob method be used to output the probability density function on a multi dimensional tensor? prob方法应该如何用于 output 概率密度 function 在多维张量上?

If I understood correctly, you can do something like:如果我理解正确,您可以执行以下操作:

mu = [0,0]
cov = [[1,0],
       [0,1]]
mv_normal = np.random.multivariate_normal(mu, cov, size=1000)
mv_normal_mean = np.mean(mv_normal , axis=0)
mv_normal_cov = np.cov(mv_normal , rowvar=0)
mv_normal_diag = np.diag(mv_normal_cov)
mv_normal_stddev = np.sqrt(mv_normal_diag)

mv_normal is just like: mv_normal 就像:

mv_normal
array([[-1.73476374,  0.17578855],
       [ 0.11866498, -0.66417069],
       [ 1.52000069, -1.3004096 ],
       ...,
       [-1.37625595, -0.46864374],
       [ 0.81659449,  0.70524036],
       [ 1.12183633,  0.14196896]])

mv_normal_mean and mv_normal_cov etc are just arrays here. mv_normal_meanmv_normal_cov等在这里只是 arrays 。 They will be used to create:它们将用于创建:

 mvn = tfd.MultivariateNormalDiag(
 loc=mv_normal_mean,
 scale_diag=mv_normal_stddev)

Values can be seen as:值可以看作:

mvn_mean
array([-0.03976356,  0.07387231])

mv_normal_cov
array([[ 1.04138867, -0.00877481],
       [-0.00877481,  0.97736496]])

And you can use contour plot for plotting.您可以使用轮廓 plot 进行绘图。

x1, x2 = np.meshgrid(mv_normal[:,0], mv_normal[:,1])
data = np.stack((x1.flatten(), x2.flatten()), axis=1)
prob = mvn.prob(data).numpy()
plt.figure(figsize = (12,9))
ax = plt.axes(projection='3d')
ax.plot_surface(x1, x2, prob.reshape(x1.shape), cmap = 'Blues')
plt.show()

That will produce as follows:这将产生如下: 在此处输入图像描述

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM