
TensorFlow gradient with tf.where returns NaN when it shouldn't

Below is reproducible code. If you run it, you will see that the first sess.run returns nan, whereas the second case gives the correct gradient value of 0.5. But given the tf.where and the condition specified, both cases should return the same value. I also simply don't understand why the gradient of the tf.where expression is nan at 1 or -1, which seem like perfectly valid input values to me.

import numpy as np
import tensorflow as tf  # TensorFlow 1.x

tf.reset_default_graph()
x = tf.get_variable('x', shape=[1])
condition = tf.less(x, 0.0)
output = tf.where(condition, -tf.log(-x + 1), tf.log(x + 1))
deriv = tf.gradients(output, x)
with tf.Session() as sess:
    print(sess.run(deriv, {x: np.array([-1.0])}))  # nan

logg = -tf.log(-x + 1)
derivv = tf.gradients(logg, x)
with tf.Session() as sess:
    print(sess.run(derivv, {x: np.array([-1.0])}))  # 0.5

Thanks for comments!

As explained in the GitHub issue provided by @mikkola, the problem stems from the internal implementation of tf.where. Basically, both alternatives (and their gradients) are computed, and only the correct part is selected by multiplication with the condition mask. Alas, if the gradient is inf or nan for the part that is not selected, then even when multiplied by 0 you get a nan that eventually propagates to the result.

Since the issue was filed in May 2016 (that's TensorFlow v0.7!) and has not been patched since, one can safely assume it won't be fixed anytime soon and start looking for a workaround.

The easiest way to fix it is to modify your statements so that they are always valid and differentiable, even for values that are not meant to be selected.

A general technique is to clip the input value to its valid domain. So in your case, for example, you could use

cond = tf.less(x, 0.0)
output = tf.where(cond,
  -tf.log(-tf.where(cond, x, tf.zeros_like(x)) + 1),
  tf.log(tf.where(cond, tf.zeros_like(x), x) + 1))
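To see why the inner clamp helps, here is the same masked-gradient arithmetic redone in NumPy with the clamp applied (again illustrative names, not TensorFlow code):

```python
import numpy as np

x = -1.0
cond = x < 0.0
# The inner tf.where clamps each branch's input into its valid domain:
safe_neg = x if cond else 0.0       # argument of -log(-x + 1), always <= 0
safe_pos = 0.0 if cond else x       # argument of log(x + 1), always >= 0
grad_neg = 1.0 / (1.0 - safe_neg)   # 0.5, finite
grad_pos = 1.0 / (safe_pos + 1.0)   # 1.0 instead of inf, thanks to the clamp
mask = 1.0 if cond else 0.0
combined = mask * grad_neg + (1.0 - mask) * grad_pos
print(combined)  # 0.5: the unselected branch contributes 0 * 1.0 = 0, not 0 * inf
```

Both branch gradients are now finite everywhere, so multiplying the unselected one by 0 yields 0 rather than nan.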

In your particular case, however, it would be simpler to just use

output = tf.sign(x) * tf.log(tf.abs(x) + 1)
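The two formulations agree on the forward values, which a quick NumPy check confirms (np.where evaluates both branches eagerly, hence the errstate guard for the off-domain points):

```python
import numpy as np

xs = np.array([-3.0, -1.0, -0.5, 0.0, 0.5, 1.0, 3.0])
with np.errstate(divide='ignore', invalid='ignore'):
    piecewise = np.where(xs < 0, -np.log(-xs + 1), np.log(xs + 1))
closed_form = np.sign(xs) * np.log(np.abs(xs) + 1)
print(np.allclose(piecewise, closed_form))  # True
# The derivative of the closed form, 1/(|x| + 1), is finite everywhere,
# so its gradient never produces inf or nan in the first place.
```

Because the closed form involves no branching at all, there is no unselected branch whose gradient could poison the result.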
