
Better way to access individual elements in a tensor

I am trying to access elements of a tensor a using the row indices stored in tensor b: for each column i, I want the element a[b[i], i].

a=tf.constant([[1,2,3,4],[5,6,7,8]])
b=tf.constant([0,1,1,0])

I want the output to be

out = [1 6 7 4]
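Put differently, the rule is out[i] = a[b[i], i] for every column i. As a reference for those semantics only (this is NumPy, not a TensorFlow solution), integer-array indexing expresses it by pairing b with the column indices 0..n-1:

```python
import numpy as np

a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
b = np.array([0, 1, 1, 0])

# Pair row index b[i] with column index i for every column i,
# so out[i] = a[b[i], i] (integer-array "fancy" indexing).
out = a[b, np.arange(a.shape[1])]
print(out)  # [1 6 7 4]
```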

What I have tried:

out=[]
for i in range(a.shape[1]):
    out.append(a[b[i],i])

out=tf.stack(out) #[1 6 7 4]

This gives the correct output, but I'm looking for a better, more compact way to do it.

Also, my approach doesn't work when a has a shape like (2, None), because range(a.shape[1]) cannot iterate over an unknown dimension. It would help if the answer covered this case too.

Thanks
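A loop-free sketch for the (2, None) case, assuming TensorFlow 2.x: read the width at run time with tf.shape(a)[1] instead of the static a.shape[1] (which is None while tracing), pair each row index b[i] with its column index i, and let tf.gather_nd pick a[b[i], i]. The function name select_per_column is hypothetical.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x. The input signature declares the width as None,
# mirroring the (2, None) shape from the question.
@tf.function(input_signature=[
    tf.TensorSpec(shape=(2, None), dtype=tf.int32),
    tf.TensorSpec(shape=(None,), dtype=tf.int32),
])
def select_per_column(a, b):
    # a.shape[1] is None while tracing, so read the width at run time instead.
    n_cols = tf.shape(a)[1]
    # Build (row, col) pairs [[b[0], 0], [b[1], 1], ...] with shape (n_cols, 2).
    indices = tf.stack([b, tf.range(n_cols)], axis=1)
    # gather_nd returns a[row, col] for each pair.
    return tf.gather_nd(a, indices)

a = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
b = tf.constant([0, 1, 1, 0])
print(select_per_column(a, b).numpy())  # [1 6 7 4]
```

Because the width comes from tf.shape, the same traced function accepts inputs with any number of columns.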

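You can use tf.one_hot() and tf.boolean_mask(). The idea is to build a mask that marks row b[i] in each column i and then keep only the masked elements. Neither op needs to iterate over a.shape[1], so this also works when the width of a is only known at run time.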

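import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.placeholder / tf.Session)
import numpy as np

# a has 2 rows and a width unknown until run time; b holds one row index per column.
a_tf = tf.placeholder(shape=(2, None), dtype=tf.int32)
b_tf = tf.placeholder(shape=(None,), dtype=tf.int32)

# one_hot(b, 2) has shape (n_cols, 2): row i is 1 at position b[i] and 0 elsewhere.
index = tf.one_hot(b_tf, a_tf.shape[0])
# Transposing a to (n_cols, 2) aligns each mask row with one column of a,
# so boolean_mask keeps exactly a[b[i], i] from every row.
out = tf.boolean_mask(tf.transpose(a_tf), index)

a = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
b = np.array([0, 1, 1, 0])
with tf.Session() as sess:
    print(sess.run(out, feed_dict={a_tf: a, b_tf: b}))

# Output:
# [1 6 7 4]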
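Under TensorFlow 2.x, tf.placeholder and tf.Session are only available through tf.compat.v1. A minimal sketch of the same one_hot + boolean_mask idea in TF 2.x style, assuming TF 2.x; the helper name select_with_mask is hypothetical, and the mask is cast to bool explicitly so the code does not rely on boolean_mask accepting a float mask.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x; the width of a is left as None, as in the question.
@tf.function(input_signature=[
    tf.TensorSpec(shape=(2, None), dtype=tf.int32),
    tf.TensorSpec(shape=(None,), dtype=tf.int32),
])
def select_with_mask(a, b):
    # one_hot(b, depth=2) has shape (n_cols, 2); row i is True only at b[i].
    mask = tf.cast(tf.one_hot(b, depth=a.shape[0]), tf.bool)
    # Transpose a to (n_cols, 2) so each mask row lines up with one column of a.
    return tf.boolean_mask(tf.transpose(a), mask)

a = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])
b = tf.constant([0, 1, 1, 0])
print(select_with_mask(a, b).numpy())  # [1 6 7 4]
```

One caveat of the mask approach: tf.one_hot produces an all-zero row for an index outside [0, depth), so an out-of-range b[i] silently drops that column from the result instead of raising an error.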
