簡體   English   中英

為什么帶有 tf.where 的 tf.gradient 返回 None?

[英]Why tf.gradient with tf.where returns None?

我想根據某些條件更改輸入或/和輸出。 但畢業是無。

如何獲得漸變並保持選擇,如下面的代碼?

input1, input2 = ..., ...
output1, output2 = model(input1), model(input2)

input = tf.where(tf.less(output1, output2), input1, input2)
output = tf.where(tf.less(output1, output2), output1, output2)

grad, = tf.gradient(output, input)

我打印輸入和輸出的類型,它們的尺寸與輸入1/輸出1(輸入2/輸出2)相同。 如果我只計算 tf.gradient(output1,input1),那就沒問題了。 它們之間有什么區別?

您可以這樣做:

input1, input2 = ..., ...
output1, output2 = model(input1), model(input2)
mask = tf.less(output1, output2)
input = tf.where(mask, input1, input2)
output = tf.where(mask, output1, output2)
grad = tf.add(*tf.gradients(output, [input1, input2]))

tf.gradient將在尚未選擇相應輸入的位置返回兩個零的張量。 因此, grad將為input保留正確的聚合梯度。

您的原始方法不起作用,因為從TensorFlow的角度來看, inputoutput之間沒有依賴關系。 input從計算input1input2output也被從計算input1input2 ,但在從圖中沒有路徑inputoutput ,所以沒有梯度。

如果有人仍然堅持使用 NaN 輸入,tf.where 可以完全替換為:

tf.minimum(tensor_having_nans, value_that_replaces_nans)

此外, tf.maximum有效,並且漸變不會截斷。 在 'inf' 的情況下,只有tf.minimum有效

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM