[英]How to replace certain values in Tensorflow tensor with the values of the other tensor?
[英]How to replace certain parts of a tensor on the condition in keras?
我想在具有 tensorflow 后端的 keras 張量上執行類似於 np.where 的操作。 這意味着假設我有兩個張量:diff 和 sum。 我將這些向量划分為:
rel_dev = diff / sum
對於 np.arrays 我會寫:
rel_dev = np.where((diff == 0.0) & (sum == 0.0), 0.0, rel_dev)
rel_dev = np.where((diff != 0.0) & (sum == 0.0), np.sign(diff), rel_dev)
也就是說,例如,如果我在 diff 和 sum 中都為零,我希望我不會得到 np.Inf,但將 rel_dev 設置為零。 現在在帶有張量的 keras 中它不起作用。 我已經嘗試過 K.switch、K.set_value 等。據我了解,它適用於整個張量,但不適用於單獨的部分,對嗎? 它可以在不設置這些條件的情況下工作,但我不知道實際發生在哪里。 我還沒有成功調試它。
您能否告訴我如何在 Keras 中為 rel_dev 編寫兩個條件?
您可以像這樣在 Keras 中執行此操作:
import keras.backend as K
diff = K.constant([0, 1, 2, -2, 3, 0])
sum = K.constant([2, 4, 1, 0, 5, 0])
rel_dev = diff / sum
d0 = K.equal(diff, 0)
s0 = K.equal(sum, 0)
rel_dev = K.switch(d0 & s0, K.zeros_like(rel_dev), rel_dev)
rel_dev = K.switch(~d0 & s0, K.sign(diff), rel_dev)
print(K.eval(rel_dev))
# [ 0. 0.25 2. -1. 0.6 0. ]
編輯:上面的公式有一個隱蔽的問題,即即使結果是正確的, nan
值也會通過梯度傳播回來(即因為除以零得到inf
或nan
,並且將inf
或nan
乘以零得到nan
) . 事實上,如果你檢查梯度:
gd, gs = K.gradients(rel_dev, (diff, sum))
print(K.eval(gd))
# [0.5 0.25 1. nan 0.2 nan]
print(K.eval(gs))
# [-0. -0.0625 -2. nan -0.12 nan]
您可以用來避免這種情況的技巧是以不影響結果但阻止nan
值的方式更改除法中的sum
,例如:
import keras.backend as K
diff = K.constant([0, 1, 2, -2, 3, 0])
sum = K.constant([2, 4, 1, 0, 5, 0])
d0 = K.equal(diff, 0)
s0 = K.equal(sum, 0)
# sum zeros are replaced by ones on division
rel_dev = diff / K.switch(s0, K.ones_like(sum), sum)
rel_dev = K.switch(d0 & s0, K.zeros_like(rel_dev), rel_dev)
rel_dev = K.switch(~d0 & s0, K.sign(diff), rel_dev)
print(K.eval(rel_dev))
# [ 0. 0.25 2. -1. 0.6 0. ]
gd, gs = K.gradients(rel_dev, (diff, sum))
print(K.eval(gd))
# [0.5 0.25 1. 0. 0.2 0. ]
print(K.eval(gs))
# [-0. -0.0625 -2. 0. -0.12 0. ]
您可以使用 tensorflow 的where
function 對張量做您想做的事情。
是的,Tensorflow 的哪里是您要查找的內容,如果您將張量轉換為 nparray 對 nparray 執行所有操作,那么您可以使用 tensor.numpy() 這將返回張量的 numpy 數組。 您可以使用“tf.convert_to_tensor”API 將 numpy 數組返回到張量。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.