繁体   English   中英

TensorFlow 梯度与 tf.where 在不应该返回 NaN

[英]TensorFlow gradient with tf.where returns NaN when it shouldn't

下面是可重现的代码。 如果你运行它,你会看到在第一次 sess 运行中,结果是 nan,而第二种情况给出了正确的梯度值 0.5。 但是根据指定的 tf.where 和条件,它们应该返回相同的值。 我也根本不明白为什么 tf.where function 梯度是 nan 为 1 或 -1,这对我来说似乎是完全好的输入值。

tf.reset_default_graph()
x = tf.get_variable('x', shape=[1])
condition = tf.less(x, 0.0)
output = tf.where(condition, -tf.log(-x + 1), tf.log(x + 1))
deriv = tf.gradients(output, x)
with tf.Session() as sess:
    print(sess.run(deriv, {x:np.array([-1])}))

logg = -tf.log(-x+1)
derivv = tf.gradients(logg, x)
with tf.Session() as sess:
    print(sess.run(derivv, {x:np.array([-1])}))

感谢您的评论!

正如@mikkola 提供的github 问题中所解释的,问题源于tf.where的内部实现。 基本上,计算了两个替代方案(及其梯度),并且通过乘法条件仅选择正确的部分。 唉,如果选择的部分的梯度是infnan ,即使乘以 0,您也会得到最终传播到结果的nan

由于该问题已于 2016 年 5 月提交(即 tensorflow v0.7!)并且此后未修补,因此可以有把握地假设这不会很快出现并开始寻找解决方法。

修复它的最简单方法是修改您的语句,使它们始终有效且可区分,即使对于不打算选择的值也是如此。

一种通用技术是将输入值裁剪在其有效域内。 因此,例如在您的情况下,您可以使用

cond = tf.less(x, 0.0)
output = tf.where(cond,
  -tf.log(-tf.where(cond, x, 0) + 1),
  tf.log(tf.where(cond, 0, x) + 1))

但是,在您的特定情况下,使用会更简单

output = tf.sign(x) * tf.log(tf.abs(x) + 1)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM