[英]Better way to access individual elements in a tensor
I am trying to access the elements of a tensor a
, with the indexes defined in tensor b
. 我正在尝试使用张量
b
定义的索引访问张量a
的元素。
a=tf.constant([[1,2,3,4],[5,6,7,8]])
b=tf.constant([0,1,1,0])
I want the output to be 我希望输出是
out = [1 6 7 4]
What I have tried: 我尝试过的
out=[]
for i in range(a.shape[1]):
out.append(a[b[i],i])
out=tf.stack(out) #[1 6 7 4]
This is giving the correct output, but I'm looking for a better and a compact way to do it. 这给出了正确的输出,但是我正在寻找一种更好,更紧凑的方式来实现。
Also my logic doesnt work when the shape of a
is something like (2,None)
since I cannot iterate with range(a.shape[1])
, it would help me if the answer included this case too 同时我的逻辑不工作时的造型
a
是一样的东西(2,None)
,因为我不能重复range(a.shape[1])
它将如果答案包含在此情况下,也帮助我
Thanks 谢谢
You can use tf.one_hot()
and tf.boolean_mask()
. 您可以使用
tf.one_hot()
和tf.boolean_mask()
。
import tensorflow as tf
import numpy as np
a_tf = tf.placeholder(shape=(2,None),dtype=tf.int32)
b_tf = tf.placeholder(shape=(None,),dtype=tf.int32)
index = tf.one_hot(b_tf,a_tf.shape[0])
out = tf.boolean_mask(tf.transpose(a_tf),index)
a=np.array([[1,2,3,4],[5,6,7,8]])
b=np.array([0,1,1,0])
with tf.Session() as sess:
print(sess.run(out,feed_dict={a_tf:a,b_tf:b}))
# print
[1 6 7 4]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.