简体   繁体   English

如何使用 TensorFlow 中的指定索引访问 3D 张量的元素?

[英]How can I access elements of a 3D tensor using specified indices in TensorFlow?

I'm trying to get the rows of a 3D tensor in a specific order of indices.我正在尝试以特定的索引顺序获取 3D 张量的行。 Here are the inputs:以下是输入:

import tensorflow as tf

matrix = tf.constant([
    [[0, 1], [2, 3], [4, 5], [6, 7]], 
    [[8, 9], [10, 11], [12, 13], [14, 15]], 
    [[16, 17], [18, 19], [20, 21], [22, 23]], 
    [[24, 25], [26, 27], [28, 29], [30, 31]], 
    [[32, 33], [34, 35], [36, 37], [38, 39]]
])

indx = tf.constant([[3,2,1,0], [0,1,2,3], [1,0,3,2], [0,3,1,2], [1,2,3,0]])

# required output tensor:
[[[6, 7], [4, 5], [2, 3], [0, 1]],
 [[8, 9], [10, 11], [12, 13], [14, 15]],
 [[18, 19], [16, 17], [22, 23], [20, 21]],
 [[24, 25], [30, 31], [26, 27], [28, 29]],
 [[34, 35], [36, 37], [38, 39], [32, 33]]]

I'm struggling with tf.gather_nd() .我正在努力tf.gather_nd() Any suggestion?有什么建议吗? I can see it's happening here but I'm not sure how to apply on entire matrix without using for loop or tf.map_fn我可以看到它在这里发生,但我不确定如何在不使用for循环或tf.map_fn的情况下应用于整个矩阵

print(tf.gather_nd(matrix[0], tf.expand_dims(indx, -1)[0]).numpy().tolist())
print(tf.gather_nd(matrix[1], tf.expand_dims(indx, -1)[1]).numpy().tolist())
print(tf.gather_nd(matrix[2], tf.expand_dims(indx, -1)[2]).numpy().tolist())
print(tf.gather_nd(matrix[3], tf.expand_dims(indx, -1)[3]).numpy().tolist())
print(tf.gather_nd(matrix[4], tf.expand_dims(indx, -1)[4]).numpy().tolist())

"""
[[6, 7], [4, 5], [2, 3], [0, 1]]
[[8, 9], [10, 11], [12, 13], [14, 15]]
[[18, 19], [16, 17], [22, 23], [20, 21]]
[[24, 25], [30, 31], [26, 27], [28, 29]]
[[34, 35], [36, 37], [38, 39], [32, 33]]
"""

EDIT: I asked a similar question with respect to numpy.编辑:我问了一个关于 numpy 的类似问题。 A clever indexing answer does solves the numpy version, but it's hard to apply it on Tensors.一个聪明的索引答案确实解决了 numpy 版本,但很难将它应用于张量。 Feel free to take a look at the accepted answer here: How can I get elements from 3D matrix using specified indices in numpy?随意看看这里接受的答案: 如何使用 numpy 中的指定索引从 3D 矩阵中获取元素?

Duh, that was stupid;呃,那太愚蠢了; There is already a very great function available that works on multi-dimensional array in tensorflow;已经有一个非常棒的 function 可用于 tensorflow 中的多维数组; tf.gather() Check out the batch_dims argument for more information. tf.gather()查看batch_dims参数以获取更多信息。

>> tf.gather(matrix, indx, batch_dims=1).numpy().tolist()
[[[6, 7], [4, 5], [2, 3], [0, 1]],
 [[8, 9], [10, 11], [12, 13], [14, 15]],
 [[18, 19], [16, 17], [22, 23], [20, 21]],
 [[24, 25], [30, 31], [26, 27], [28, 29]],
 [[34, 35], [36, 37], [38, 39], [32, 33]]]

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM