繁体   English   中英

tf.where() 在操作张量时表现不佳

[英]tf.where() not behaving as expected for manipulating tensors

我尝试了以下代码:

a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19))

不会产生与以下相同的结果:

a = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19)

这里 x 是二进制变量 0 或 1。b 是介于 0 和 1 之间的实数。

有什么我想念的吗?
我比较 2 个答案的方式是 tf.reduce_sum(a)

找到的解决方案:对于 x = 0 或 x = 1,这 2 个确实等价。我使用的数据是一个 2D 张量,其中有些位不是 0 或 1。这是通过tf.unique(tf.reshape(x, (-1,))

代码示例:

# When x = 0.0 

x = 0.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # 0 + (1-x)*tf.math.log(1 - b + 1e-19) = -0.6931472

# When x = 1.0 

x = 1.0
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.6931472 from x*tf.math.log(b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) + 0 = -0.6931472


# When x = 0.4 

x = 0.4
b = 0.5
a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19)) # -0.41588834 from (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) # x*tf.math.log(b + 1e-19) +  (1-x)*tf.math.log(1 - b + 1e-19) = -0.27725 + -0.41588 = -0.6931471824645996.

您问题中提到的两种代码都会产生相同结果的唯一情况是 x = 1 或 0。

tf.reduce_sum(a)

这里,a 是一个标量,所以这不会改变 a 的值。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM