繁体   English   中英

基于列索引矩阵(tensorflow/numpy)收集矩阵中的条目

[英]Gathering entries in a matrix based on a matrix of column indices (tensorflow/numpy)

一个小例子来证明我需要什么

我有一个关于在 tensorflow 中收集的问题。 假设我有一个值的张量(出于某种原因我关心):

test1 = tf.round(5*tf.random.uniform(shape=(2,3)))

这给了我这个输出:

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 2.],
       [4., 5., 0.]], dtype=float32)>

而且我还有一个索引列索引的张量,我想在每一行上挑选出来:

test_ind = tf.constant([[0,1,0,0,1],
                        [0,1,1,1,0]], dtype=tf.int64)

我想收集这些,以便从第一行(第 0 行)中挑选出第 0、1、0、0、1 列中的项目,第二行也一样。

所以这个例子的输出应该是:

<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[1., 1., 1., 1., 1.],
       [4., 5., 5., 5., 4.]], dtype=float32)>

我的尝试

所以我想出了一个一般的方法,我写了下面的函数gather_matrix_indices(),它将接受一个值的张量和一个索引的张量,并完全按照我上面指定的方式进行操作。

def gather_matrix_indices(input_arr, index_arr):
    row, _ = input_arr.shape
    
    li = []
    
    for i in range(row):
        li.append(tf.expand_dims(tf.gather(params=input_arr[i], indices=index_arr[i]), axis=0))
        
    return tf.concat(li, axis=0)

我的问题

我只是想知道,有没有办法只使用 tensorflow 或 numpy 方法来做到这一点? 我能想出的唯一解决方案是编写自己的函数,该函数遍历每一行并收集该行中所有列的索引。 我还没有遇到运行时问题,但我更愿意尽可能使用内置的 tensorflow 或 numpy 方法。 我之前也尝试过 tf.gather,但我不知道这种特殊情况是否可以使用 tf.gather 和 tf.gather_nd 的任何组合。 如果有人有建议,我将不胜感激。

您可以为此使用gather_nd() 让它工作看起来有点棘手。 让我试着用形状来解释这一点。

我们得到了test1 -> [2, 3]test_ind_col_ind -> [2, 5] test_ind_col_ind只有列索引,但您还需要行索引才能使用gather_nd() 要将gather_nd()[2,3]张量一起使用,我们需要创建一个test_ind -> [2, 5, 2]大小的张量。 这个新的test_ind最里面的维度对应于您要从test1索引的各个索引。 这里我们有格式为(<row index>, <col index>)inner most dimension = 2 换句话说,看看test_ind的形状,

[ 2 , 5 , 2 ]
    |     |
    V     |
  (2,5)   |       <- The size of the final tensor   
          V
         (2,)     <- The full index to a scalar in your input tensor
import tensorflow as tf

test1 = tf.round(5*tf.random.uniform(shape=(2,3)))
print(test1)

test_ind_col_ind = tf.constant([[0,1,0,0,1],
                        [0,1,1,1,0]], dtype=tf.int64)[:, :, tf.newaxis]

test_ind_row_ind = tf.repeat(tf.range(2, dtype=tf.int64)[:, tf.newaxis, tf.newaxis], 5, axis=1)

test_ind = tf.concat([test_ind_format, test_ind], axis=-1)

res = tf.gather_nd(indices=test_ind, params=test1)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM