[英]TensorFlow gradient with tf.where returns NaN when it shouldn't
下面是可重现的代码。 如果你运行它,你会看到在第一次 sess 运行中,结果是 nan,而第二种情况给出了正确的梯度值 0.5。 但是根据指定的 tf.where 和条件,它们应该返回相同的值。 我也根本不明白为什么 tf.where function 梯度是 nan 为 1 或 -1,这对我来说似乎是完全好的输入值。
tf.reset_default_graph()
x = tf.get_variable('x', shape=[1])
condition = tf.less(x, 0.0)
output = tf.where(condition, -tf.log(-x + 1), tf.log(x + 1))
deriv = tf.gradients(output, x)
with tf.Session() as sess:
print(sess.run(deriv, {x:np.array([-1])}))
logg = -tf.log(-x+1)
derivv = tf.gradients(logg, x)
with tf.Session() as sess:
print(sess.run(derivv, {x:np.array([-1])}))
感谢您的评论!
正如@mikkola 提供的github 问题中所解释的,问题源于tf.where
的内部实现。 基本上,计算了两个替代方案(及其梯度),并且通过乘法条件仅选择正确的部分。 唉,如果未选择的部分的梯度是inf
或nan
,即使乘以 0,您也会得到最终传播到结果的nan
。
由于该问题已于 2016 年 5 月提交(即 tensorflow v0.7!)并且此后未修补,因此可以有把握地假设这不会很快出现并开始寻找解决方法。
修复它的最简单方法是修改您的语句,使它们始终有效且可区分,即使对于不打算选择的值也是如此。
一种通用技术是将输入值裁剪在其有效域内。 因此,例如在您的情况下,您可以使用
cond = tf.less(x, 0.0)
output = tf.where(cond,
-tf.log(-tf.where(cond, x, 0) + 1),
tf.log(tf.where(cond, 0, x) + 1))
但是,在您的特定情况下,使用会更简单
output = tf.sign(x) * tf.log(tf.abs(x) + 1)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.