[英]Gathering entries in a matrix based on a matrix of column indices (tensorflow/numpy)
一个小例子来证明我需要什么
我有一个关于在 tensorflow 中收集的问题。 假设我有一个值的张量(出于某种原因我关心):
test1 = tf.round(5*tf.random.uniform(shape=(2,3)))
这给了我这个输出:
<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 2.],
[4., 5., 0.]], dtype=float32)>
而且我还有一个索引列索引的张量,我想在每一行上挑选出来:
test_ind = tf.constant([[0,1,0,0,1],
[0,1,1,1,0]], dtype=tf.int64)
我想收集这些,以便从第一行(第 0 行)中挑选出第 0、1、0、0、1 列中的项目,第二行也一样。
所以这个例子的输出应该是:
<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[1., 1., 1., 1., 1.],
[4., 5., 5., 5., 4.]], dtype=float32)>
我的尝试
所以我想出了一个一般的方法,我写了下面的函数gather_matrix_indices(),它将接受一个值的张量和一个索引的张量,并完全按照我上面指定的方式进行操作。
def gather_matrix_indices(input_arr, index_arr):
row, _ = input_arr.shape
li = []
for i in range(row):
li.append(tf.expand_dims(tf.gather(params=input_arr[i], indices=index_arr[i]), axis=0))
return tf.concat(li, axis=0)
我的问题
我只是想知道,有没有办法只使用 tensorflow 或 numpy 方法来做到这一点? 我能想出的唯一解决方案是编写自己的函数,该函数遍历每一行并收集该行中所有列的索引。 我还没有遇到运行时问题,但我更愿意尽可能使用内置的 tensorflow 或 numpy 方法。 我之前也尝试过 tf.gather,但我不知道这种特殊情况是否可以使用 tf.gather 和 tf.gather_nd 的任何组合。 如果有人有建议,我将不胜感激。
您可以为此使用gather_nd()
。 让它工作看起来有点棘手。 让我试着用形状来解释这一点。
我们得到了test1 -> [2, 3]
和test_ind_col_ind -> [2, 5]
。 test_ind_col_ind
只有列索引,但您还需要行索引才能使用gather_nd()
。 要将gather_nd()
与[2,3]
张量一起使用,我们需要创建一个test_ind -> [2, 5, 2]
大小的张量。 这个新的test_ind
最里面的维度对应于您要从test1
索引的各个索引。 这里我们有格式为(<row index>, <col index>)
的inner most dimension = 2
。 换句话说,看看test_ind
的形状,
[ 2 , 5 , 2 ]
| |
V |
(2,5) | <- The size of the final tensor
V
(2,) <- The full index to a scalar in your input tensor
import tensorflow as tf
test1 = tf.round(5*tf.random.uniform(shape=(2,3)))
print(test1)
test_ind_col_ind = tf.constant([[0,1,0,0,1],
[0,1,1,1,0]], dtype=tf.int64)[:, :, tf.newaxis]
test_ind_row_ind = tf.repeat(tf.range(2, dtype=tf.int64)[:, tf.newaxis, tf.newaxis], 5, axis=1)
test_ind = tf.concat([test_ind_format, test_ind], axis=-1)
res = tf.gather_nd(indices=test_ind, params=test1)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.