tf.where() on broadcast arrays

I have two arrays (x is 1-D and y is 2-D). I have calculated the array "diff", which is the broadcast difference (x - y[:,None]). I would like to replace all zeros in "diff" with a large value (say 10000). This operation is trivial in NumPy, as shown below:

import numpy as np

x = np.array([1.0, 1.0, 1.0])
y = np.array([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]])
diff = x - y[:, None]
# replace exact zeros with a large value, keep everything else
diff = np.where(diff == 0.0, 10000, diff)
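To make the shapes explicit: x has shape (3,) and y has shape (2, 3), so y[:, None] has shape (2, 1, 3) and the broadcast difference is 3-D, with diff[i, 0, j] = x[j] - y[i, j]. A minimal NumPy sketch using the same data as above:

```python
import numpy as np

x = np.array([1.0, 1.0, 1.0])                     # shape (3,)
y = np.array([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]])  # shape (2, 3)

# y[:, None] inserts a new axis -> shape (2, 1, 3); x (3,) broadcasts against it.
diff = x - y[:, None]
print(diff.shape)  # (2, 1, 3)

# Element-wise boolean mask with the same shape as diff.
mask = diff == 0.0

# np.where broadcasts the scalar 10000.0 to the mask's shape.
result = np.where(mask, 10000.0, diff)
print(result)
```

The first row of y equals x, so that slice becomes all 10000.0; the second row of y is all zeros, so that slice stays at 1.0.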

However, I am not able to reproduce the same behavior in TensorFlow. I tried the following code:

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder / tf.Session)

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
diff = x - y[:,None]
diff_zero = tf.cast(tf.zeros_like(diff),tf.float32)
diff_big = tf.cast(tf.ones_like(diff)*10000,tf.float32)

# this line does not behave like np.where
diff = tf.where(diff==diff_zero, diff_big, diff)

sess = tf.Session()
diff_array = sess.run(diff, feed_dict={x: [1.0,1.0,1.0], y: [[1.0,1.0,1.0],[0.0,0.0,0.0]]})

Any work-around would be appreciated.

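I figured out how to do it: I had to use tf.equal() instead of "==". In TensorFlow 1.x, "==" between two tensors compares Python object identity and returns a single Python bool rather than an element-wise boolean tensor, so tf.where() never receives a proper mask. tf.equal() performs the element-wise comparison. The following code behaves just like the NumPy version: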

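import tensorflow as tf  # TensorFlow 1.x API; on 2.x use tf.compat.v1 with eager execution disabled

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
diff = x - y[:,None]

# diff is already float32, so no tf.cast is needed.
# tf.where in 1.x needs both branches to have the same shape as the condition,
# hence the full-size helper tensors.
diff_zero = tf.zeros_like(diff)
diff_big = tf.ones_like(diff) * 10000.0

# tf.equal returns an element-wise boolean tensor; "==" does not in 1.x
condition = tf.equal(diff_zero, diff)
diff = tf.where(condition, diff_big, diff)

with tf.Session() as sess:
    diff_array = sess.run(diff, feed_dict={x: [1.0,1.0,1.0], y: [[1.0,1.0,1.0],[0.0,0.0,0.0]]})
print(diff_array)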
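On TensorFlow 2.x, eager execution is the default and tf.placeholder/tf.Session are only available under tf.compat.v1, so the same replacement can be written more directly. This is a sketch assuming TensorFlow 2.x: there, tf.where broadcasts its arguments (unlike tf.compat.v1.where, which needs both branches to match the condition's shape), so a scalar replacement value works without the zeros_like/ones_like helpers. In 2.x "==" on tensors is also element-wise, but tf.equal is kept here to mirror the answer above.

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution (the default)

x = tf.constant([1.0, 1.0, 1.0])                     # shape (3,)
y = tf.constant([[1.0, 1.0, 1.0], [0.0, 0.0, 0.0]])  # shape (2, 3)

# Same broadcast as in NumPy: result has shape (2, 1, 3).
diff = x - y[:, None]

# Element-wise boolean mask; the scalar 0.0 is broadcast against diff.
condition = tf.equal(diff, 0.0)

# tf.where (2.x) broadcasts the scalar replacement to the mask's shape.
big = tf.constant(10000.0, dtype=diff.dtype)
result = tf.where(condition, big, diff)

print(result.numpy())
```

Because the inputs are eager tensors, result.numpy() returns a regular NumPy array that can be compared directly with the np.where output.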

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint this content, please credit the site URL or the original address. For questions, contact yoyou2525@163.com.
