[英]Using mixture of Multivariate Normal distributions with Tensorflow-probability.layers
我正在嘗試使用張量流概率層來創建多元正態分布的混合。 當我為此使用 IndependentNormal 層時,它工作正常,但是當我使用 MultivariateNormalTriL 層時,我遇到了 event_shape 的問題。 我將這些層與 MixtureSameFamily 層結合起來。 以下代碼應該很好地說明我的問題,並且應該在 google colab 中工作:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.keras as keras
# Shorthand for the tfp Keras-compatible distribution layers.
tfpl = tfp.layers
print(tf.__version__)
# >> '1.15.0-rc3'
# but I get the same result with extra warnings in 1.14.0
print(tfp.__version__)
# >> '0.7.0'
# A standalone MultivariateNormalTriL layer: event_shape is correctly [100].
print(tfpl.MultivariateNormalTriL(100)(
keras.layers.Input(shape=tfpl.MultivariateNormalTriL.params_size(100))
))
# >> tfp.distributions.MultivariateNormalTriL("multivariate_normal_tri_l_4/MultivariateNormalTriL/MultivariateNormalTriL/",
# batch_shape=[?], event_shape=[100], dtype=float32)
# A standalone IndependentNormal layer: event_shape is also [100].
print(tfpl.IndependentNormal((100,))(
keras.layers.Input(shape=(tfpl.IndependentNormal.params_size(100),))
))
# >> tfp.distributions.Independent("Independentindependent_normal_2/IndependentNormal/Normal/",
# batch_shape=[?], event_shape=[100], dtype=float32)
# Wrapping MultivariateNormalTriL in MixtureSameFamily: event_shape becomes
# unknown ([?]) instead of the expected [100] -- this is the reported problem.
print(tfpl.MixtureSameFamily(16, tfpl.MultivariateNormalTriL(100))(
keras.layers.Input(shape=(16*tfpl.MultivariateNormalTriL.params_size(100),))
))
# >> tfp.distributions.MixtureSameFamily("mixture_same_family_2/MixtureSameFamily/MixtureSameFamily/",
# batch_shape=[?], event_shape=[?], dtype=float32)
# Wrapping IndependentNormal the same way keeps event_shape=[100], so the two
# component layers behave differently inside MixtureSameFamily.
print(tfpl.MixtureSameFamily(16, tfpl.IndependentNormal((100,)))(
keras.layers.Input(shape=(16*tfpl.IndependentNormal.params_size(100,),))
))
# >> tfp.distributions.MixtureSameFamily("mixture_same_family_3/MixtureSameFamily/MixtureSameFamily/",
# batch_shape=[?], event_shape=[100], dtype=float32)
盡管 MultivariateNormalTriL 和 IndependentNormal 具有相同的 batch_shape 和 event_shape,但將它們與 MixtureSameFamily 組合會導致不同的事件形狀。
所以我的問題是:為什么它們會導致不同的事件形狀,以及如何為具有不同(不一定是對角)協方差矩陣和 event_shape=[100] 的多元正態分布混合獲得一個層?
編輯:同樣的情況發生在 tensorflow 概率版本 0.8
我誤解了 MixtureSameFamily 層是如何工作的,所以在閱讀了所有相關層的代碼后,我想出了以下解決方案:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.compat.v1 as tf1
import numpy as np
tfl = tfp.layers
tfd = tfp.distributions
class MixtureMultivariateNormalTriL(tfl.DistributionLambda):
    """Keras layer producing a mixture of full-covariance multivariate normals.

    Builds a ``tfd.Mixture`` of ``num_components`` ``tfd.MultivariateNormalTriL``
    distributions (each with its own lower-triangular scale, i.e. a full,
    not-necessarily-diagonal covariance), yielding event_shape=[event_size].
    """

    def __init__(self, num_components, event_size, validate_args=False, scale='default', **kwargs):
        """
        Initialize the MixtureMultivariateNormalTriL layer.

        :param num_components: Number of component distributions in the mixture (int)
        :param event_size: Scalar `int` representing the size of single draw from this
            distribution.
        :param validate_args: Python `bool`, default `False`. When `True` distribution
            parameters are checked for validity despite possibly degrading runtime
            performance. When `False` invalid inputs may silently render incorrect
            outputs.
            Default value: False
        :param scale: type of tfp.bijectors.ScaleTriL used for the multivariate normal distribution.
            If 'default', we use tfp.bijectors.ScaleTriL(
                diag_shift=np.array(1e-5, params.dtype.as_numpy_dtype()),
                validate_args=validate_args)
            (using the same convention as in tfpl.MultivariateNormalTriL)
            If `exponential`, we use scale_tril = tfp.bijectors.ScaleTriL(
                diag_bijector=tfp.bijectors.Exp(),
                diag_shift=None,
                validate_args=validate_args
            )
            Alternatively a tfp.bijectors.ScaleTriL object can be passed.
            Default value: "default"
        """
        # Drop any caller-supplied factory; this layer always builds its own
        # distribution via `new`.
        kwargs.pop('make_distribution_fn', None)
        super().__init__(
            lambda t: MixtureMultivariateNormalTriL.new(t, num_components, event_size, validate_args, scale),
            **kwargs
        )
        # Stored for get_config() so the layer round-trips through serialization.
        self._event_size = event_size
        self._num_components = num_components
        # BUG FIX: previously hard-coded to False, which made get_config()
        # serialize the wrong value when validate_args=True was passed.
        self._validate_args = validate_args
        self._scale = scale

    @staticmethod
    def new(params, num_components, event_size, validate_args=False, scale='default', name=None):
        """Create the mixture distribution from a flat parameter tensor.

        `params` is expected to have shape
        (batch_size, num_components + num_components * MVNTriL params_size):
        the first `num_components` entries are mixture logits, the rest are
        the per-component location/scale parameters.
        """
        with tf1.name_scope(name, 'MixtureMultivariateNormalTriL',
                            [params, num_components, event_size]):
            params = tf.convert_to_tensor(value=params, name='params', dtype_hint=tf.float32)
            num_components = tf.convert_to_tensor(
                value=num_components, name='num_components', dtype_hint=tf.int32)
            # Leading slice holds the (unnormalized) mixture weights.
            mixture_dist = tfd.Categorical(logits=params[..., :num_components])
            # Remaining entries are reshaped to (batch, num_components, -1):
            # one parameter vector per component.
            component_params = tf.reshape(
                params[..., num_components:],
                tf.concat([tf.shape(input=params)[:-1], [num_components, -1]],
                          axis=0))
            params_per_component = tf.unstack(component_params, axis=1)
            if scale == "default":
                # Same convention as tfpl.MultivariateNormalTriL: small diagonal
                # shift keeps the Cholesky factor strictly positive-definite.
                scale_tril = tfp.bijectors.ScaleTriL(
                    diag_shift=np.array(1e-5, params.dtype.as_numpy_dtype()),
                    validate_args=validate_args)
            elif scale == "exponential":
                scale_tril = tfp.bijectors.ScaleTriL(
                    diag_bijector=tfp.bijectors.Exp(validate_args=validate_args),
                    diag_shift=None,
                    validate_args=validate_args
                )
            else:
                # Caller supplied a ready-made ScaleTriL bijector.
                assert isinstance(scale, tfp.bijectors.ScaleTriL)
                scale_tril = scale
            # For some reason tfp doesn't manage to infer the event_shape of our
            # component distributions; applying this Reshape bijector pins the
            # static event_shape to (event_size,) so tfd.Mixture accepts them.
            reshape = tfp.bijectors.Reshape(event_shape_out=(event_size,))
            distributions = [
                reshape(
                    tfd.MultivariateNormalTriL(
                        loc=par[..., :event_size],
                        scale_tril=scale_tril(par[..., event_size:]),
                        validate_args=validate_args
                    )
                )
                for par in params_per_component
            ]
            return tfd.Mixture(
                mixture_dist,
                distributions,
                validate_args=validate_args
            )

    @staticmethod
    def params_size(num_components, event_size, name=None):
        """Number of input parameters needed: logits plus per-component MVN params."""
        with tf1.name_scope(name, "MixtureMultivariateNormalTriL_params_size",
                            [num_components, event_size]):
            return num_components + num_components * tfl.MultivariateNormalTriL.params_size(event_size)

    def get_config(self):
        """Return the layer config so the layer can be re-instantiated."""
        base_config = super().get_config()
        base_config["num_components"] = self._num_components
        base_config["event_size"] = self._event_size
        base_config["scale"] = self._scale
        base_config["validate_args"] = self._validate_args
        return base_config
不過,我仍在努力對其進行全面測試。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.