繁体   English   中英

如何在 keras 的条件下更换张量的某些部分?

[英]How to replace certain parts of a tensor on the condition in keras?

我想在具有 tensorflow 后端的 keras 张量上执行类似于 np.where 的操作。 这意味着假设我有两个张量:diff 和 sum。 我将这些向量划分为:

rel_dev = diff / sum

对于 np.arrays 我会写:

rel_dev = np.where((diff == 0.0) & (sum == 0.0), 0.0, rel_dev)
rel_dev = np.where((diff != 0.0) & (sum == 0.0), np.sign(diff), rel_dev)

也就是说,例如,如果我在 diff 和 sum 中都为零,我希望我不会得到 np.Inf,但将 rel_dev 设置为零。 现在在带有张量的 keras 中它不起作用。 我已经尝试过 K.switch、K.set_value 等。据我了解,它适用于整个张量,但不适用于单独的部分,对吗? 它可以在不设置这些条件的情况下工作,但我不知道实际发生在哪里。 我还没有成功调试它。

您能否告诉我如何在 Keras 中为 rel_dev 编写两个条件?

您可以像这样在 Keras 中执行此操作:

import keras.backend as K

diff = K.constant([0, 1, 2, -2, 3, 0])
sum = K.constant([2, 4, 1, 0, 5, 0])
rel_dev = diff / sum
d0 = K.equal(diff, 0)
s0 = K.equal(sum, 0)
rel_dev = K.switch(d0 & s0, K.zeros_like(rel_dev), rel_dev)
rel_dev = K.switch(~d0 & s0, K.sign(diff), rel_dev)
print(K.eval(rel_dev))
# [ 0.    0.25  2.   -1.    0.6   0.  ]

编辑:上面的公式有一个隐蔽的问题,即即使结果是正确的, nan值也会通过梯度传播回来(即因为除以零得到infnan ,并且将infnan乘以零得到nan ) . 事实上,如果你检查梯度:

gd, gs = K.gradients(rel_dev, (diff, sum))
print(K.eval(gd))
# [0.5  0.25 1.    nan 0.2   nan]
print(K.eval(gs))
# [-0.     -0.0625 -2.         nan -0.12       nan]

您可以用来避免这种情况的技巧是以不影响结果但阻止nan值的方式更改除法中的sum ,例如:

import keras.backend as K

diff = K.constant([0, 1, 2, -2, 3, 0])
sum = K.constant([2, 4, 1, 0, 5, 0])
d0 = K.equal(diff, 0)
s0 = K.equal(sum, 0)
# sum zeros are replaced by ones on division
rel_dev = diff / K.switch(s0, K.ones_like(sum), sum)
rel_dev = K.switch(d0 & s0, K.zeros_like(rel_dev), rel_dev)
rel_dev = K.switch(~d0 & s0, K.sign(diff), rel_dev)
print(K.eval(rel_dev))
# [ 0.    0.25  2.   -1.    0.6   0.  ]
gd, gs = K.gradients(rel_dev, (diff, sum))
print(K.eval(gd))
# [0.5  0.25 1.   0.   0.2  0.  ]
print(K.eval(gs))
# [-0.     -0.0625 -2.      0.     -0.12    0.    ]

您可以使用 tensorflow 的where function 对张量做您想做的事情。

是的,Tensorflow 的哪里是您要查找的内容,如果您将张量转换为 nparray 对 nparray 执行所有操作,那么您可以使用 tensor.numpy() 这将返回张量的 numpy 数组。 您可以使用“tf.convert_to_tensor”API 将 numpy 数组返回到张量。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM