![](/img/trans.png)
[英]How to replace certain values in Tensorflow tensor with the values of the other tensor?
[英]How to replace certain parts of a tensor on the condition in keras?
我想在具有 tensorflow 后端的 keras 张量上执行类似于 np.where 的操作。 这意味着假设我有两个张量:diff 和 sum。 我将这些向量划分为:
rel_dev = diff / sum
对于 np.arrays 我会写:
rel_dev = np.where((diff == 0.0) & (sum == 0.0), 0.0, rel_dev)
rel_dev = np.where((diff != 0.0) & (sum == 0.0), np.sign(diff), rel_dev)
也就是说,例如,如果我在 diff 和 sum 中都为零,我希望我不会得到 np.Inf,但将 rel_dev 设置为零。 现在在带有张量的 keras 中它不起作用。 我已经尝试过 K.switch、K.set_value 等。据我了解,它适用于整个张量,但不适用于单独的部分,对吗? 它可以在不设置这些条件的情况下工作,但我不知道实际发生在哪里。 我还没有成功调试它。
您能否告诉我如何在 Keras 中为 rel_dev 编写两个条件?
您可以像这样在 Keras 中执行此操作:
import keras.backend as K
diff = K.constant([0, 1, 2, -2, 3, 0])
sum = K.constant([2, 4, 1, 0, 5, 0])
rel_dev = diff / sum
d0 = K.equal(diff, 0)
s0 = K.equal(sum, 0)
rel_dev = K.switch(d0 & s0, K.zeros_like(rel_dev), rel_dev)
rel_dev = K.switch(~d0 & s0, K.sign(diff), rel_dev)
print(K.eval(rel_dev))
# [ 0. 0.25 2. -1. 0.6 0. ]
编辑:上面的公式有一个隐蔽的问题,即即使结果是正确的, nan
值也会通过梯度传播回来(即因为除以零得到inf
或nan
,并且将inf
或nan
乘以零得到nan
) . 事实上,如果你检查梯度:
gd, gs = K.gradients(rel_dev, (diff, sum))
print(K.eval(gd))
# [0.5 0.25 1. nan 0.2 nan]
print(K.eval(gs))
# [-0. -0.0625 -2. nan -0.12 nan]
您可以用来避免这种情况的技巧是以不影响结果但阻止nan
值的方式更改除法中的sum
,例如:
import keras.backend as K
diff = K.constant([0, 1, 2, -2, 3, 0])
sum = K.constant([2, 4, 1, 0, 5, 0])
d0 = K.equal(diff, 0)
s0 = K.equal(sum, 0)
# sum zeros are replaced by ones on division
rel_dev = diff / K.switch(s0, K.ones_like(sum), sum)
rel_dev = K.switch(d0 & s0, K.zeros_like(rel_dev), rel_dev)
rel_dev = K.switch(~d0 & s0, K.sign(diff), rel_dev)
print(K.eval(rel_dev))
# [ 0. 0.25 2. -1. 0.6 0. ]
gd, gs = K.gradients(rel_dev, (diff, sum))
print(K.eval(gd))
# [0.5 0.25 1. 0. 0.2 0. ]
print(K.eval(gs))
# [-0. -0.0625 -2. 0. -0.12 0. ]
您可以使用 tensorflow 的where
function 对张量做您想做的事情。
是的,Tensorflow 的哪里是您要查找的内容,如果您将张量转换为 nparray 对 nparray 执行所有操作,那么您可以使用 tensor.numpy() 这将返回张量的 numpy 数组。 您可以使用“tf.convert_to_tensor”API 将 numpy 数组返回到张量。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.