TensorFlow Probability: incompatible shapes error on log_prob of JointDistributionSequential samples

Hi, I'm trying to figure out whether I made a mistake with TFP shapes or whether this is a TFP bug. I can sample from this simple joint distribution, which uses a Normal prior on the (shared) mean of a 3-dimensional multivariate normal (MVN) and three draws from a Half-Cauchy as a prior on the diagonal of the MVN scale matrix (i.e. the per-dimension standard deviations). Computing log_prob on a batch of samples, however, fails.
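Written out, the model is as follows. This is a sketch that reads scale_diag=lambda_k as per-dimension standard deviations, so the covariance is diag(λ_k²), and uses z for the 3-vector z_k. The joint log-density should be a single scalar per joint draw, which means the three Half-Cauchy terms have to be summed over k:

```latex
\begin{aligned}
z_0 &\sim \mathcal{N}(0,\, 1), \\
\lambda_k &\sim \operatorname{HalfCauchy}(0,\, 2), \qquad k = 1, 2, 3, \\
\mathbf{z} \mid z_0, \boldsymbol{\lambda} &\sim \mathcal{N}\!\left(z_0 \mathbf{1}_3,\ \operatorname{diag}(\lambda_1^2, \lambda_2^2, \lambda_3^2)\right), \\[4pt]
\log p(z_0, \boldsymbol{\lambda}, \mathbf{z}) &= \log \mathcal{N}(z_0;\, 0, 1)
  + \sum_{k=1}^{3} \log \operatorname{HalfCauchy}(\lambda_k;\, 0, 2)
  + \log \mathcal{N}\!\left(\mathbf{z};\, z_0 \mathbf{1}_3,\ \operatorname{diag}(\lambda_1^2, \lambda_2^2, \lambda_3^2)\right).
\end{aligned}
```

If the middle term is left as three separate values instead of being summed, the log-density has shape [3] per draw rather than being a scalar.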

I'm using TensorFlow 2.2.0, and this bug occurs with both tensorflow-probability 0.10.1 and the nightly build tensorflow-probability 0.12.0-dev20200719.

This code should run standalone:

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions 

joint_model = tfd.JointDistributionSequential([
    tfd.Normal(loc=0., scale=1., name='z_0'),       
    tfd.HalfCauchy(loc=tf.zeros([3]), scale=2., name='lambda_k'),
    lambda lambda_k, z_0: tfd.MultivariateNormalDiag( # z_k ~ MVN(z_0, lambda_k)
        loc=z_0[...,tf.newaxis],
        scale_diag=lambda_k,
        name='z_k'),
])

# These work
joint_model.sample()
joint_model.sample(4)
joint_model.log_prob(joint_model.sample())

# This breaks 
joint_model.log_prob(joint_model.sample(4))

Here's the error message:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-41-61bae7ab690a> in <module>
      1 joint_model.log_prob(joint_model.sample())
      2 # ERROR
----> 3 joint_model.log_prob(joint_model.sample(4))

~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow_probability/python/distributions/joint_distribution.py in log_prob(self, *args, **kwargs)
    479         model_flatten_fn=self._model_flatten,
    480         model_unflatten_fn=self._model_unflatten)
--> 481     return self._call_log_prob(value, **unmatched_kwargs)
    482 
    483   # Override the base method to capture *args and **kwargs, so we can

~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow_probability/python/distributions/distribution.py in _call_log_prob(self, value, name, **kwargs)
    944     with self._name_and_control_scope(name, value, kwargs):
    945       if hasattr(self, '_log_prob'):
--> 946         return self._log_prob(value, **kwargs)
    947       if hasattr(self, '_prob'):
    948         return tf.math.log(self._prob(value, **kwargs))

~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow_probability/python/distributions/joint_distribution.py in _log_prob(self, value)
    391   def _log_prob(self, value):
    392     xs = self._map_measure_over_dists('log_prob', value)
--> 393     return sum(maybe_check_wont_broadcast(xs, self.validate_args))
    394 
    395   @distribution_util.AppendDocstring(kwargs_dict={

~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py in binary_op_wrapper(x, y)
    982     with ops.name_scope(None, op_name, [x, y]) as name:
    983       if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
--> 984         return func(x, y, name=name)
    985       elif not isinstance(y, sparse_tensor.SparseTensor):
    986         try:

~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py in _add_dispatch(x, y, name)
   1274     return gen_math_ops.add(x, y, name=name)
   1275   else:
-> 1276     return gen_math_ops.add_v2(x, y, name=name)
   1277 
   1278 

~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py in add_v2(x, y, name)
    478         pass  # Add nodes to the TensorFlow graph.
    479     except _core._NotOkStatusException as e:
--> 480       _ops.raise_from_not_ok_status(e, name)
    481   # Add nodes to the TensorFlow graph.
    482   _, _, _op, _outputs = _op_def_library._apply_op_helper(

~/miniconda3/envs/latent2/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
   6651   message = e.message + (" name: " + name if name is not None else "")
   6652   # pylint: disable=protected-access
-> 6653   six.raise_from(core._status_to_exception(e.code, message), None)
   6654   # pylint: enable=protected-access
   6655 

~/miniconda3/envs/latent2/lib/python3.7/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: Incompatible shapes: [4] vs. [4,3] [Op:AddV2]
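The traceback shows JointDistributionSequential adding the per-component log-probabilities with Python's sum(...), and TensorFlow's AddV2 follows the same broadcasting rules as NumPy. The sketch below reproduces only that arithmetic with NumPy placeholder arrays of the shapes involved. sum_log_probs is a hypothetical helper, not a TFP function, and NumPy raises ValueError where TensorFlow raises InvalidArgumentError. By the same rule, the single-sample call that appears to work in the question would broadcast to three values rather than return one scalar.

```python
import numpy as np


def sum_log_probs(parts):
    """Hypothetical stand-in for the sum(...) call in the traceback:
    a plain Python sum(), so per-component arrays are added with broadcasting."""
    return sum(parts)


# Placeholder values; only the shapes matter here.
# One joint draw: Normal -> [], HalfCauchy (batch_shape [3]) -> [3], MVN -> [].
single = sum_log_probs([np.zeros([]), np.zeros([3]), np.zeros([])])
print(single.shape)  # (3,): broadcasting silently yields three values, not one

# Four joint draws: Normal -> [4], HalfCauchy -> [4, 3], MVN -> [4].
# Broadcasting aligns trailing dimensions, and 4 vs 3 is incompatible.
try:
    sum_log_probs([np.zeros([4]), np.zeros([4, 3]), np.zeros([4])])
    batched_error = None
except ValueError as err:
    batched_error = err
print(batched_error)
```

The single-draw case does not fail; it only hides the bug.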

This has been driving me absolutely insane, so any help is appreciated.

Thanks @Miles Turpin for sharing the solution reference. For the benefit of the community, I am providing here (in the answer section) the solution given by jeffpollock9 on GitHub.

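There is a mix-up between batch_shape and event_shape in the joint distribution. HalfCauchy(loc=tf.zeros([3]), ...) is a batch of three scalar distributions (batch_shape [3], event_shape []), so its log_prob on sample(4) has shape [4, 3], one value per component, while the Normal and MVN terms each have shape [4]. JointDistributionSequential adds these terms together, and [4] + [4, 3] cannot broadcast. Wrapping the Half-Cauchy distribution with tfd.Independent(..., reinterpreted_batch_ndims=1) reinterprets that batch dimension as an event dimension (batch_shape [], event_shape [3]), so the three log-densities are summed and the term has shape [4].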

Please refer to the working code below:

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    joint_model = tfd.JointDistributionSequential([
        tfd.Normal(loc=0., scale=1., name='z_0'),
        # Independent turns the batch of 3 Half-Cauchy variables into a single
        # event of shape [3], so its log_prob is summed over the 3 components.
        tfd.Independent(
            tfd.HalfCauchy(loc=tf.zeros([3]), scale=2., name='lambda_k'),
            reinterpreted_batch_ndims=1),
        # z_k ~ MVN(z_0, diag(lambda_k)^2); earlier samples are passed in reverse order.
        lambda lambda_k, z_0: tfd.MultivariateNormalDiag(
            loc=z_0[..., tf.newaxis],
            scale_diag=lambda_k,
            name='z_k'),
    ])

    print(joint_model)

    print(joint_model.log_prob(joint_model.sample(4)))

Output:

tfp.distributions.JointDistributionSequential("JointDistributionSequential", batch_shape=[[], [], []], event_shape=[[], [3], [3]], dtype=[float32, float32, float32])

tf.Tensor([-14.330933  -16.854149  -15.07704    -6.9233823], shape=(4,), dtype=float32)

For background, please refer to this explanation of Shapes and Probability Distributions (sample, batch, and event shapes).
