TensorFlow: Error with tf.where()

I'm not sure why tf.where() doesn't work as expected. I want to take the value from a wherever yt is less than 5, and the value from b otherwise.

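```python
import tensorflow as tf

sess = tf.InteractiveSession()  # TF 0.x/1.x graph mode; installs itself as the default session so .eval() works
yt = tf.constant([10, 1, 10])
a = tf.constant([1, 2, 3])
b = tf.constant([3, 4, 5])
# Take a where yt < 5, otherwise b
tf.where(tf.less(yt, [5]), a, b).eval()
```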

This fails with the error:

```
where() takes at most 2 arguments (3 given)
```

Can you tell me why I am getting this error? Is there any other way to do this?

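The signature of tf.where() changed between TensorFlow 0.10 and TensorFlow 0.12+. In 0.10 it took only the condition (plus an optional name), which is why passing three arguments fails, and it returned the coordinates of the true elements. In 0.12+ it takes three tensor arguments (condition, x, y) and returns a single output that picks each element from x or y, replacing the former tf.select().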
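
To make the difference between the two forms concrete, here is a minimal pure-Python sketch of what each one computes for 1-D inputs. The helpers `where_select` and `where_indices` are hypothetical names used only for illustration; they are not TensorFlow APIs, and they ignore dtypes, higher ranks and broadcasting.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def where_select(condition: Sequence[bool], x: Sequence[T], y: Sequence[T]) -> List[T]:
    """Hypothetical sketch of the three-argument form (TF 0.12+):
    take x[i] where condition[i] is True, otherwise y[i]. No broadcasting."""
    if not (len(condition) == len(x) == len(y)):
        raise ValueError("condition, x and y must have the same length")
    return [xi if c else yi for c, xi, yi in zip(condition, x, y)]


def where_indices(condition: Sequence[bool]) -> List[List[int]]:
    """Hypothetical sketch of the one-argument form (TF 0.10):
    coordinates of the True elements, one row per True element."""
    return [[i] for i, c in enumerate(condition) if c]


yt = [10, 1, 10]
a = [1, 2, 3]
b = [3, 4, 5]
cond = [v < 5 for v in yt]  # elementwise equivalent of tf.less(yt, [5])

print(where_select(cond, a, b))  # [3, 2, 5]
print(where_indices(cond))       # [[1]]
```

With the question's inputs, the three-argument form yields [3, 2, 5], which is what the original tf.where(tf.less(yt, [5]), a, b).eval() returns on TensorFlow 0.12 or later.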

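As Himaprasoon suggests, upgrading to the latest version of TensorFlow should fix your problem. If you have to stay on 0.10, the same elementwise selection is available there as tf.select(tf.less(yt, [5]), a, b).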
