
tf.where() not behaving as expected when manipulating tensors

I have tried the following code:

a = tf.where(tf.greater_equal(x,1.0),x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19))

but it does not produce the same results as:

a = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19)

Here x is a binary variable (either 0 or 1), and b is a real value between 0 and 1.

Is there something I am missing? I compared the two results with tf.reduce_sum(a).
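A short derivation shows why the two forms should agree for binary x. Writing ε for the 1e-19 constant, a_where for the tf.where result and a_sum for the arithmetic one (labels introduced here, not in the question), tf.where keeps exactly one branch while the arithmetic form adds both branches weighted by x and 1 - x:

$$
a_{\text{where}} = \begin{cases} x\,\log(b+\varepsilon) & \text{if } x \ge 1 \\ (1-x)\,\log(1-b+\varepsilon) & \text{otherwise} \end{cases}
\qquad
a_{\text{sum}} = x\,\log(b+\varepsilon) + (1-x)\,\log(1-b+\varepsilon)
$$

$$
x = 1:\ a_{\text{sum}} = \log(b+\varepsilon) + 0 = a_{\text{where}}, \qquad
x = 0:\ a_{\text{sum}} = 0 + \log(1-b+\varepsilon) = a_{\text{where}}
$$

For x exactly 0 or 1 one weight is zero, so the two expressions coincide; a difference can only come from entries that are not exactly 0 or 1.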

Solution found: the two expressions are indeed equivalent when x = 0 or x = 1. My data was a 2D tensor in which some entries were neither 0 nor 1. I found this with tf.unique(tf.reshape(x, (-1,))), which flattens x and whose first output lists the distinct values.
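The same check can be sketched without TensorFlow. In this minimal version nested Python lists stand in for the 2D tensor, and `distinct_values` / `non_binary_values` and the sample data are hypothetical names chosen for illustration. Like tf.unique, it reports distinct values in order of first occurrence.

```python
def distinct_values(x_2d):
    """Plain-Python stand-in for tf.unique(tf.reshape(x, (-1,))).y:
    flatten a 2D list row by row and keep each value once, in first-seen order."""
    return list(dict.fromkeys(v for row in x_2d for v in row))


def non_binary_values(x_2d):
    """Distinct entries that are neither 0 nor 1 (these break the equivalence)."""
    return [v for v in distinct_values(x_2d) if v not in (0, 1)]


# Hypothetical label data, as nested lists, with one stray value.
labels = [[0.0, 1.0, 1.0],
          [1.0, 0.4, 0.0]]
print(distinct_values(labels))    # [0.0, 1.0, 0.4]
print(non_binary_values(labels))  # [0.4]
```

An empty result from `non_binary_values` means every entry is exactly 0 or 1, which is the condition under which the two expressions agree.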

Code sample:

import tensorflow as tf

# When x = 0.0
x = 0.0
b = 0.5
a = tf.where(tf.greater_equal(x, 1.0), x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19))
# a = -0.6931472, taken from the false branch (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19)
# c = 0 + (1-x)*tf.math.log(1 - b + 1e-19) = -0.6931472

# When x = 1.0
x = 1.0
b = 0.5
a = tf.where(tf.greater_equal(x, 1.0), x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19))
# a = -0.6931472, taken from the true branch x*tf.math.log(b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19)
# c = x*tf.math.log(b + 1e-19) + 0 = -0.6931472

# When x = 0.4
x = 0.4
b = 0.5
a = tf.where(tf.greater_equal(x, 1.0), x*tf.math.log(b + 1e-19), (1-x)*tf.math.log(1 - b + 1e-19))
# a = -0.41588834, taken from the false branch (1-x)*tf.math.log(1 - b + 1e-19)
c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19)
# c = x*tf.math.log(b + 1e-19) + (1-x)*tf.math.log(1 - b + 1e-19) = -0.27725... + -0.41588... = -0.6931472
# (with b = 0.5, log(b) == log(1 - b), so c equals log(0.5) for every x)

In general, the two expressions in your question produce the same result only when x is exactly 0 or 1.
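The size of the mismatch for other values of x follows from the two definitions. Writing ε for the 1e-19 constant, a_where for the tf.where result and a_sum for the arithmetic one (labels introduced here), the difference is exactly the branch that tf.where drops:

$$
a_{\text{sum}} - a_{\text{where}} = \begin{cases} x\,\log(b+\varepsilon) & \text{if } x < 1 \\ (1-x)\,\log(1-b+\varepsilon) & \text{if } x \ge 1 \end{cases}
$$

The first case vanishes at x = 0 and the second at x = 1; otherwise the difference is nonzero unless b is essentially 1 (first case) or essentially 0 (second case). For the x = 0.4, b = 0.5 sample it gives 0.4 · log(0.5) ≈ -0.2773, which is the gap between c ≈ -0.6931 and a ≈ -0.4159.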

As for comparing with tf.reduce_sum(a): in the samples above a is a scalar, so tf.reduce_sum returns the value of a unchanged.
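With 2D data, tf.reduce_sum adds up every element, so a single non-binary entry is enough to make the two totals differ, and the totals alone do not say where. A minimal plain-Python sketch, assuming nested lists stand in for the tensors (the helper names and the sample data are hypothetical), compares element by element to locate the offending positions:

```python
import math

EPS = 1e-19  # same epsilon as in the question


def where_elem(x, b):
    # tf.where-style: keep exactly one branch per element
    return x * math.log(b + EPS) if x >= 1.0 else (1 - x) * math.log(1 - b + EPS)


def sum_elem(x, b):
    # arithmetic form: weight both branches by x and 1 - x
    return x * math.log(b + EPS) + (1 - x) * math.log(1 - b + EPS)


def total(fn, x_2d, b_2d):
    # plays the role of tf.reduce_sum over the element-wise result
    return sum(fn(xv, bv) for x_row, b_row in zip(x_2d, b_2d)
               for xv, bv in zip(x_row, b_row))


def mismatches(x_2d, b_2d, tol=1e-9):
    """Return (row, col, x) for every position where the two forms disagree."""
    out = []
    for i, (x_row, b_row) in enumerate(zip(x_2d, b_2d)):
        for j, (xv, bv) in enumerate(zip(x_row, b_row)):
            if not math.isclose(where_elem(xv, bv), sum_elem(xv, bv), abs_tol=tol):
                out.append((i, j, xv))
    return out


x = [[0.0, 1.0], [1.0, 0.4]]  # hypothetical labels; 0.4 is the stray entry
b = [[0.2, 0.7], [0.9, 0.3]]  # hypothetical predictions in (0, 1)
print(total(where_elem, x, b), total(sum_elem, x, b))  # the totals differ
print(mismatches(x, b))  # [(1, 1, 0.4)]
```

These are double-precision Python floats, so the last digits differ from TensorFlow's default float32 results, but the location of the mismatch is the same.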
