[英]How to use a tensor for indexing another tensor in tensorflow
I have a data
tensor of dimensios [BXNX 3]
, and I have an indices
tensor of dimensions [BXM]
. 我有dimensios
[BXNX 3]
的data
张量,并且我有维度[BXM]
的indices
张量。 I wish to extract a [BXMX 3]
tensor from the data
tensor using the indices
tensor. 我希望使用
indices
张量从data
张量中提取[BXMX 3]
张量。
I have this code that works : 我有这段代码可以工作:
new_data= []
for i in range(B):
new_data.append(tf.gather(data[i], indices[i]))
new_data= tf.stack(new_data)
However, I am sure it is not the right way to do this. 但是,我确信这不是正确的方法。 Does anyone know a better way?
有谁知道更好的方法? (I guess I should use
tf.gather_nd()
somehow but I couldn't figure out how) (我想我应该以某种方式使用
tf.gather_nd()
但我不知道怎么做)
I have seen several answers to similar questions here . 我在这里看到了类似问题的几个答案。 However I could not find the solution to my problem.
但是,我找不到解决问题的方法。
You can use tf.gather_nd()
with code like this: 您可以将
tf.gather_nd()
与以下代码一起使用:
import tensorflow as tf
# B = 3
# N = 4
# M = 2
# [B x N x 3]
data = tf.constant([
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]],
[[100, 101, 102], [103, 104, 105], [106, 107, 108], [109, 110, 111]],
[[200, 201, 202], [203, 204, 205], [206, 207, 208], [209, 210, 211]],
])
# [B x M]
indices = tf.constant([
[0, 2],
[1, 3],
[3, 2],
])
indices_shape = tf.shape(indices)
indices_help = tf.tile(tf.reshape(tf.range(indices_shape[0]), [indices_shape[0], 1]) ,[1, indices_shape[1]]);
indices_ext = tf.concat([tf.expand_dims(indices_help, 2), tf.expand_dims(indices, 2)], axis = 2)
new_data = tf.gather_nd(data, indices_ext)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print('data')
print(sess.run(data))
print('\nindices')
print(sess.run(indices))
print('\nnew_data')
print(sess.run(new_data))
new_data
will be: new_data
将是:
[[[ 0 1 2]
[ 6 7 8]]
[[103 104 105]
[109 110 111]]
[[209 210 211]
[206 207 208]]]
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.