
TensorFlow: how to get the indices of the True values in a tensor?

I have a 1-D boolean tensor like:

[False, False, True, False, True, False]

How can I find the indices of all the True values?

My current solution is:

  1. Convert the values to 1 and 0.
  2. Use the argmax API to find one index, then set that element to False / 0.
  3. Use argmax again to find the next True / 1 value, and repeat.
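For reference, the argmax loop described in these steps might look roughly like this. It is a sketch assuming TensorFlow 2.x eager mode (the question does not say which version), and the helper name true_indices_by_argmax is hypothetical. Because tf.argmax does not guarantee which index it returns on ties, the collected indices are sorted at the end.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def true_indices_by_argmax(mask):
    """Hypothetical helper: collect indices of True values by repeated argmax.

    Needs one argmax pass per True element, which is what makes this approach slow.
    """
    values = tf.cast(mask, tf.int32)  # step 1: True -> 1, False -> 0
    indices = []
    while tf.reduce_any(values > 0):
        # step 2: position of some remaining 1 (tie-breaking is not guaranteed)
        i = int(tf.argmax(values).numpy())
        indices.append(i)
        # step 3: zero out that element so the next argmax finds another 1
        values = tf.tensor_scatter_nd_update(values, [[i]], [0])
    return sorted(indices)


mask = tf.constant([False, False, True, False, True, False])
print(true_indices_by_argmax(mask))  # [2, 4]
```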

But this solution is clumsy: it needs one argmax call per True value.

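import tensorflow as tf  # TensorFlow 1.x graph mode; in TF 2.x tf.Session moved to tf.compat.v1

a = tf.constant([False, False, True, False, True], dtype=tf.bool)
# With one argument, tf.where returns the coordinates of the True elements
# as an int64 tensor of shape [num_true, rank], here [[2], [4]]
b = tf.where(a)
sess = tf.Session()
print(sess.run(b))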
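In TensorFlow 2.x there is no session and ops run eagerly, so the same tf.where answer can be written directly. This is a sketch assuming TF 2.x; the reshape to 1-D is an extra step added here to turn the [num_true, 1] coordinates of a 1-D input into a flat index vector.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution, no tf.Session)

a = tf.constant([False, False, True, False, True, False])

# One-argument tf.where: int64 coordinates of the True elements, shape [num_true, rank]
coords = tf.where(a)

# For a 1-D input, flatten the [num_true, 1] result into a plain index vector
idx = tf.reshape(coords, [-1])
print(idx.numpy().tolist())  # [2, 4]
```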

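Is this what you're looking for? [k for k, value in enumerate(tensor) if value] (This works on a Python list, or on an eager tensor in TensorFlow 2.x; in TF 1.x graph mode a tf.Tensor cannot be iterated.)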
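A quick check of that comprehension, assuming TensorFlow 2.x eager mode. Iterating the tensor directly yields one scalar tensor per element; converting to a Python list once first gives the same indices with plain Python bools.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

tensor = tf.constant([False, False, True, False, True, False])

# Iterating an eager tensor yields scalar bool tensors, which work in an `if`
from_tensor = [k for k, value in enumerate(tensor) if value]

# Converting to a Python list first avoids creating a scalar tensor per element
from_list = [k for k, value in enumerate(tensor.numpy().tolist()) if value]

print(from_tensor, from_list)  # [2, 4] [2, 4]
```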

In [1]: import tensorflow as tf

In [2]: a = tf.constant([False, False, True, True])

In [3]: a_n = [tf.cond(tf.equal(v, tf.constant(True)), lambda: tf.constant(k), lambda: tf.constant(-1)) for k, v in enumerate(tf.unstack(a))]

In [4]: sess = tf.Session()

In [5]: sess.run(a_n)
Out[5]: [-1, -1, 2, 3]
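The session above builds one tf.cond per element. As an alternative (a substitution, not part of the original answer), the three-argument form of tf.where makes the same selection in a single vectorized op, and tf.boolean_mask keeps only the True positions. Sketch assumes TensorFlow 2.x eager mode.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

a = tf.constant([False, False, True, True])
positions = tf.range(tf.size(a))  # [0, 1, 2, 3], int32

# Three-argument tf.where: take from `positions` where a is True, else -1
marked = tf.where(a, positions, tf.fill(tf.shape(a), -1))
print(marked.numpy().tolist())  # [-1, -1, 2, 3]

# boolean_mask drops the False entries, leaving only the True indices
compact = tf.boolean_mask(positions, a)
print(compact.numpy().tolist())  # [2, 3]
```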

Hope this helps...

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

 
Guangdong ICP Filing No. 18138465  © 2020-2024 STACKOOM.COM