簡體   English   中英

如何在 keras 的條件下更換張量的某些部分?

[英]How to replace certain parts of a tensor on the condition in keras?

我想在具有 tensorflow 后端的 keras 張量上執行類似於 np.where 的操作。 這意味着假設我有兩個張量:diff 和 sum。 我將這些向量划分為:

rel_dev = diff / sum

對於 np.arrays 我會寫:

rel_dev = np.where((diff == 0.0) & (sum == 0.0), 0.0, rel_dev)
rel_dev = np.where((diff != 0.0) & (sum == 0.0), np.sign(diff), rel_dev)

也就是說,例如,如果我在 diff 和 sum 中都為零,我希望我不會得到 np.Inf,但將 rel_dev 設置為零。 現在在帶有張量的 keras 中它不起作用。 我已經嘗試過 K.switch、K.set_value 等。據我了解,它適用於整個張量,但不適用於單獨的部分,對嗎? 它可以在不設置這些條件的情況下工作,但我不知道實際發生在哪里。 我還沒有成功調試它。

您能否告訴我如何在 Keras 中為 rel_dev 編寫兩個條件?

您可以像這樣在 Keras 中執行此操作:

import keras.backend as K

diff = K.constant([0, 1, 2, -2, 3, 0])
sum = K.constant([2, 4, 1, 0, 5, 0])
rel_dev = diff / sum
d0 = K.equal(diff, 0)
s0 = K.equal(sum, 0)
rel_dev = K.switch(d0 & s0, K.zeros_like(rel_dev), rel_dev)
rel_dev = K.switch(~d0 & s0, K.sign(diff), rel_dev)
print(K.eval(rel_dev))
# [ 0.    0.25  2.   -1.    0.6   0.  ]

編輯:上面的公式有一個隱蔽的問題,即即使結果是正確的, nan值也會通過梯度傳播回來(即因為除以零得到infnan ,並且將infnan乘以零得到nan ) . 事實上,如果你檢查梯度:

gd, gs = K.gradients(rel_dev, (diff, sum))
print(K.eval(gd))
# [0.5  0.25 1.    nan 0.2   nan]
print(K.eval(gs))
# [-0.     -0.0625 -2.         nan -0.12       nan]

您可以用來避免這種情況的技巧是以不影響結果但阻止nan值的方式更改除法中的sum ,例如:

import keras.backend as K

diff = K.constant([0, 1, 2, -2, 3, 0])
sum = K.constant([2, 4, 1, 0, 5, 0])
d0 = K.equal(diff, 0)
s0 = K.equal(sum, 0)
# sum zeros are replaced by ones on division
rel_dev = diff / K.switch(s0, K.ones_like(sum), sum)
rel_dev = K.switch(d0 & s0, K.zeros_like(rel_dev), rel_dev)
rel_dev = K.switch(~d0 & s0, K.sign(diff), rel_dev)
print(K.eval(rel_dev))
# [ 0.    0.25  2.   -1.    0.6   0.  ]
gd, gs = K.gradients(rel_dev, (diff, sum))
print(K.eval(gd))
# [0.5  0.25 1.   0.   0.2  0.  ]
print(K.eval(gs))
# [-0.     -0.0625 -2.      0.     -0.12    0.    ]

您可以使用 tensorflow 的where function 對張量做您想做的事情。

是的,Tensorflow 的哪里是您要查找的內容,如果您將張量轉換為 nparray 對 nparray 執行所有操作,那么您可以使用 tensor.numpy() 這將返回張量的 numpy 數組。 您可以使用“tf.convert_to_tensor”API 將 numpy 數組返回到張量。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM