简体   繁体   English

使用 .take() 索引多维数组

[英]Use .take() to index multidimensional array

I have a multidimensional array of shape (n,x,y).我有一个形状(n,x,y)的多维数组。 For this example can use this array对于这个例子可以使用这个数组

A = array([[[ 0,  1,  2],
            [ 3,  4,  5],
            [ 6,  7,  8],
            [ 9, 10, 11]],

           [[12, 13, 14],
            [15, 16, 17],
            [18, 19, 20],
            [21, 22, 23]],

           [[24, 25, 26],
            [27, 28, 29],
            [30, 31, 32],
            [33, 34, 35]]])

I then have another multidimensional array that has index values that I want to use on the original array, A. This has shape (z,2) and the values represent row values index's然后我有另一个多维数组,它具有我想在原始数组 A 上使用的索引值。它具​​有形状 (z,2) 并且这些值表示行值索引的

Row_values = array([[0,1],
                    [0,2],
                    [1,2],
                    [1,3]])

So I want to use all the index values in row_values to apply to each of the three arrays in A so I end up with a final array of shape (12,2,3)所以我想使用 row_values 中的所有索引值来应用 A 中的三个数组中的每一个,所以我最终得到一个形状为 (12,2,3) 的最终数组

Result = ([[[0,1,2],
            [3,4,5]],
           [[0,1,2],
            [6,7,8]],
           [[3,4,5],
            [6,7,8]]
           [[3,4,5],
            [9,10,11],
           [[12,13,14],
            [15,16,17]],
           [[12,13,14],
            [18,19,20]],
           [[15,16,17],
            [18,19,20]],
           [[15,16,17],
            [21,22,23]],
           [[24,25,26],
            [27,28,29]],
           [[24,25,26],
            [30,31,32]],
           [[27,28,29],
            [30,31,32]],
           [[27,28,29],
            [33,34,35]]]

I have tried using np.take() but haven't been able to make it work.我曾尝试使用 np.take() 但未能使其正常工作。 Not sure if there's another numpy function that is easier to use不确定是否有另一个更容易使用的 numpy 函数

We can advantage of NumPy 's advanced indexing and using np.repeat and np.tile along with it.我们可以利用NumPy的高级索引并将np.repeatnp.tile与它一起使用。

cidx = np.tile(Row_values, (A.shape[0], 1))
ridx = np.repeat(np.arange(A.shape[0]), Row_values.shape[0])

out = A[ridx[:, None], cidx]
# out.shape -> (12, 2, 3)

Output:输出:

array([[[ 0,  1,  2],
        [ 3,  4,  5]],

       [[ 0,  1,  2],
        [ 6,  7,  8]],

       [[ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 3,  4,  5],
        [ 9, 10, 11]],

       [[12, 13, 14],
        [15, 16, 17]],

       [[12, 13, 14],
        [18, 19, 20]],

       [[15, 16, 17],
        [18, 19, 20]],

       [[15, 16, 17],
        [21, 22, 23]],

       [[24, 25, 26],
        [27, 28, 29]],

       [[24, 25, 26],
        [30, 31, 32]],

       [[27, 28, 29],
        [30, 31, 32]],

       [[27, 28, 29],
        [33, 34, 35]]])

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM