簡體   English   中英

TensorFlow 梯度與 tf.where 在不應該返回 NaN

[英]TensorFlow gradient with tf.where returns NaN when it shouldn't

下面是可重現的代碼。 如果你運行它,你會看到在第一次 sess 運行中,結果是 nan,而第二種情況給出了正確的梯度值 0.5。 但是根據指定的 tf.where 和條件,它們應該返回相同的值。 我也根本不明白為什么 tf.where function 梯度是 nan 為 1 或 -1,這對我來說似乎是完全好的輸入值。

tf.reset_default_graph()
x = tf.get_variable('x', shape=[1])
condition = tf.less(x, 0.0)
output = tf.where(condition, -tf.log(-x + 1), tf.log(x + 1))
deriv = tf.gradients(output, x)
with tf.Session() as sess:
    print(sess.run(deriv, {x:np.array([-1])}))

logg = -tf.log(-x+1)
derivv = tf.gradients(logg, x)
with tf.Session() as sess:
    print(sess.run(derivv, {x:np.array([-1])}))

感謝您的評論!

正如@mikkola 提供的github 問題中所解釋的,問題源於tf.where的內部實現。 基本上,計算了兩個替代方案(及其梯度),並且通過乘法條件僅選擇正確的部分。 唉,如果選擇的部分的梯度是infnan ,即使乘以 0,您也會得到最終傳播到結果的nan

由於該問題已於 2016 年 5 月提交(即 tensorflow v0.7!)並且此后未修補,因此可以有把握地假設這不會很快出現並開始尋找解決方法。

修復它的最簡單方法是修改您的語句,使它們始終有效且可區分,即使對於不打算選擇的值也是如此。

一種通用技術是將輸入值裁剪在其有效域內。 因此,例如在您的情況下,您可以使用

cond = tf.less(x, 0.0)
output = tf.where(cond,
  -tf.log(-tf.where(cond, x, 0) + 1),
  tf.log(tf.where(cond, 0, x) + 1))

但是,在您的特定情況下,使用會更簡單

output = tf.sign(x) * tf.log(tf.abs(x) + 1)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM