
tensorflow_probability: Gradients always zero when backpropagating the log_prob of a sample of a normal distribution

As part of a project I am having trouble with the gradients of a normal distribution in tensorflow_probability. I create a normal distribution and draw a sample from it. The log_prob of this sample is then fed into an optimizer to update the weights of a network.

If I take the log_prob of some constant instead, I always get non-zero gradients. Unfortunately I have not found any relevant help in tutorials or similar sources.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

def get_log_prob(mu, std):
    normal = tfd.Normal(loc=mu, scale=std)
    samples = normal.sample(sample_shape=(1,))
    log_prob = normal.log_prob(samples)
    return log_prob

const = tf.constant([0.1], dtype=np.float32)

log_prob = get_log_prob(const, 0.01)
grads = tf.gradients(log_prob, const)

with tf.Session() as sess:
    gradients = sess.run([grads])


print('gradients', gradients)

Output: gradients [array([0.], dtype=float32)]

I expect to get non-zero gradients when computing the gradient of the log_prob of a sample. Instead the output is always "0."

This is a consequence of TensorFlow Probability implementing reparameterization gradients (aka the "reparameterization trick"), and it is in fact the correct answer in certain situations. Let me show you how that 0. answer comes about.

One way to generate a sample from a normal distribution with some location and scale is to first generate a sample from a standard normal distribution (this is usually a library-provided function, e.g. tf.random.normal in TensorFlow) and then shift and scale it. Say the output of tf.random.normal is z. To get a sample x from the normal distribution with location loc and scale scale, you'd compute: x = z * scale + loc.
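A minimal numpy sketch of that shift-and-scale construction (numpy stands in for tf.random.normal here, just to show the arithmetic):

```python
import numpy as np

rng = np.random.default_rng(0)

loc, scale = 0.1, 0.01
z = rng.standard_normal()   # draw from the standard normal, N(0, 1)
x = z * scale + loc         # reparameterized draw from N(loc, scale)
```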

Now, how does one compute the value of the probability density of a number under the normal distribution? One way is to reverse that transformation, so that you're now dealing with a standard normal distribution, and then compute the log-probability density there. I.e. log_prob(x) = log_prob_std_normal((x - loc) / scale) + f(scale) (the f(scale) term comes from the change of variables involved in the transformation; its form doesn't matter for this explanation).
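To make that concrete: for the normal distribution the change-of-variables term is f(scale) = -log(scale). A small numpy sketch of the inverse-transform computation (function names are illustrative, not the TFP API):

```python
import numpy as np

def log_prob_std_normal(z):
    # log density of the standard normal N(0, 1)
    return -0.5 * z ** 2 - 0.5 * np.log(2.0 * np.pi)

def log_prob(x, loc, scale):
    # undo the shift-and-scale, then add the change-of-variables
    # term, which for the normal is f(scale) = -log(scale)
    return log_prob_std_normal((x - loc) / scale) - np.log(scale)
```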

If you now plug the first expression into the second, you get log_prob(x) = log_prob_std_normal(z) + f(scale), i.e. the loc cancels entirely! As a result, the gradient of log_prob with respect to loc is 0. This also explains why you don't get 0. when you evaluate the log probability at a constant: it is missing the forward transformation used to create the sample, so you get a (typically) non-zero gradient.
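You can check the cancellation numerically: hold the standard-normal draw z fixed, move loc, and the log-probability of the resulting sample doesn't change (a numpy sketch of the math, not TFP code):

```python
import numpy as np

def log_prob(x, loc, scale):
    z = (x - loc) / scale
    return -0.5 * z ** 2 - np.log(scale) - 0.5 * np.log(2.0 * np.pi)

z = 0.7        # a fixed standard-normal draw
scale = 0.01

def log_prob_of_sample(loc):
    x = z * scale + loc              # the sample moves with loc...
    return log_prob(x, loc, scale)   # ...and loc cancels inside log_prob

# central finite difference of log_prob_of_sample w.r.t. loc
eps = 1e-4
grad = (log_prob_of_sample(0.1 + eps) - log_prob_of_sample(0.1 - eps)) / (2 * eps)
# grad is (numerically) zero, matching the 0. gradient TFP reports
```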

So, when is this the correct behavior? Reparameterization gradients are correct when you're computing gradients of an expectation of a function under a distribution with respect to that distribution's parameters. One way to compute such an expectation is a Monte-Carlo approximation, like so: tf.reduce_mean(g(dist.sample(N)), axis=0). It sounds like that's what you're doing (where your g() is log_prob()), so it looks like the gradients are correct.
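As a sanity check that reparameterization gradients do the right thing for expectations, here's a numpy sketch estimating d/d(loc) of E[x**2] under Normal(loc, scale), whose exact value is 2 * loc (the choice g(x) = x**2 is purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(42)
loc, scale, N = 0.5, 0.2, 100_000

z = rng.standard_normal(N)
x = z * scale + loc          # N reparameterized samples

# g(x) = x**2, so by the chain rule d g(x)/d loc = 2 * x * (dx/dloc) = 2 * x;
# averaging gives the Monte-Carlo estimate of d E[g(x)] / d loc
grad_loc_estimate = np.mean(2 * x)

# exact answer: d/dloc E[x**2] = d/dloc (loc**2 + scale**2) = 2 * loc
```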
