
In tensorflow-probability, how do I update a learnable prior used only in KL-divergence?

I'm working on a variational auto-encoder and I'd like the prior used in the KL-divergence regularization of the latent distribution to have a trainable loc (mean) and scale (stddev) that get updated during training.

The snippet below is a contrived minimal example of what I'm trying to achieve. Training starts to work but then just freezes after some random number of epochs (sometimes 1, sometimes 200, but usually around 7 or 8), with no error message or anything.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors
tfkl = tf.keras.layers
tfpl = tfp.layers

# ndim, ds (the tf.data.Dataset) and N_EPOCHS are defined earlier in the gist.

# Learnable prior: loc is a plain Variable; scale is a TransformedVariable kept
# positive via softplus(x + 0.5413) + 1e-5, so an unconstrained value of 0 maps
# to a scale of roughly 1.
loc = tf.Variable(tf.random.normal([ndim], stddev=0.1, dtype=tf.float32))
scale = tfp.util.TransformedVariable(
    tf.random.normal([ndim], mean=1.0, stddev=0.1, dtype=tf.float32),
    bijector=tfb.Chain([tfb.Shift(1e-5), tfb.Softplus(), tfb.Shift(0.5413)]))
prior = tfd.Independent(tfd.Normal(loc=loc, scale=scale), reinterpreted_batch_ndims=1)

_input = tfkl.Input(shape=(1,))
_loc = tfkl.Dense(ndim, name="loc_params")(_input)
_scale = tfkl.Dense(ndim, name="untransformed_scale_params")(_input)
# Same transform as the prior's scale: log(e - 1) ~= 0.5413, so the initial scale is ~1.
_scale = tf.math.softplus(_scale + np.log(np.exp(1) - 1)) + 1e-5
_output = tfpl.DistributionLambda(
    make_distribution_fn=lambda t: tfd.Independent(tfd.Normal(loc=t[0], scale=t[1])),
    activity_regularizer=tfpl.KLDivergenceRegularizer(prior, use_exact_kl=True, weight=0.1)
)([_loc, _scale])
model = tf.keras.Model(_input, _output)
model.compile(optimizer='adam', loss=lambda y_true, model_out: -model_out.log_prob(y_true))
hist = model.fit(ds, epochs=N_EPOCHS, verbose=2)

I have a runnable gist here.

A more concrete example, and an architecture close to the one I'm trying to update and simplify, is the tfp example for disentangled_vae. In its manual training loop, a new tfd.MultivariateNormalDiag is instantiated on every loop iteration, though it is parameterized using persistent tf.Variables. I'm trying my best to avoid manual training loops, and I'm also trying to move to more Keras-like syntax, so I'd rather not do a direct port of this example. A rough sketch of that pattern follows.
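For reference, this is a minimal sketch of the pattern that example uses: the prior's parameters live in persistent tf.Variables, and a fresh distribution object is built from them on every training step inside a manual loop. The names here (latent_dim, encoder, optimizer, batch, kl_weight) are placeholders of mine, not taken from the disentangled_vae code, and the encoder is assumed to return a tfd.MultivariateNormalDiag so the exact KL is available.

import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

latent_dim = 8  # placeholder latent dimensionality
prior_loc = tf.Variable(tf.zeros([latent_dim]))
prior_raw_scale = tf.Variable(tf.zeros([latent_dim]))  # passed through softplus below to stay positive

@tf.function
def train_step(batch, encoder, optimizer, kl_weight=0.1):
    with tf.GradientTape() as tape:
        q = encoder(batch)  # assumed to return a tfd.MultivariateNormalDiag over the latent code
        # A new prior is instantiated every step, but it is parameterized by the
        # persistent Variables above, so loc and scale receive gradients.
        prior = tfd.MultivariateNormalDiag(
            loc=prior_loc, scale_diag=tf.nn.softplus(prior_raw_scale) + 1e-5)
        loss = kl_weight * tf.reduce_mean(tfd.kl_divergence(q, prior))
        # (a real VAE would add the reconstruction term here)
    variables = encoder.trainable_variables + [prior_loc, prior_raw_scale]
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss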

Any advice is greatly appreciated. Thanks!


Edit: The activity_regularizer seems to work fine when attached to a latent (bottleneck) distribution. I have a more complete example in this Colab notebook. As this works in my architecture, I'm no longer in need of an answer.
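As a rough illustration of what this edit describes, here is a minimal sketch of attaching the same KLDivergenceRegularizer to an encoder's bottleneck distribution instead of the model output, reusing the prior, ndim and layer aliases from the snippet above. The hidden width, input_dim and variable names are placeholders of mine, not taken from the Colab notebook.

encoder_in = tfkl.Input(shape=(input_dim,))   # input_dim is a placeholder
h = tfkl.Dense(64, activation='relu')(encoder_in)
latent_params = tfkl.Dense(tfpl.IndependentNormal.params_size(ndim))(h)
# Per the edit above, attaching the regularizer here, to the bottleneck
# distribution, trains without the freezing problem.
latent = tfpl.IndependentNormal(
    ndim,
    activity_regularizer=tfpl.KLDivergenceRegularizer(prior, use_exact_kl=True, weight=0.1)
)(latent_params)
encoder = tf.keras.Model(encoder_in, latent)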

However, I highly doubt having model fitting freeze is desirable behaviour, so this remains a problem.

As the machinery works in most circumstances, just not in the contrived freezing example above, I no longer consider this a question that needs an answer.

I have reported the errorless freezing behaviour on the tensorflow-probability repository's issues page. See here.
