
Extract an array from an array without a loop in Python

I am trying to extract part of an array from an array.

Let's assume I have an array array1 with shape (M, N, P). For my specific case, M = 10, N = 5, P = 2000. I have another array, array2, of shape (M, N, 1), which contains the starting points of the interesting data in array1 along the last axis. I want to extract 50 points of this data starting at the indices given by array2, something like this:

array1[:, :, array2:array2 + 50] 

I would expect a result of shape (M, N, 50). Unfortunately, I get the error:

TypeError: only integer scalar arrays can be converted to a scalar index

Sure, I could also get the result by looping over the array, but I feel there must be a smarter way, because I need this quite often.
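For reference, the loop-based version the question wants to avoid might look like the sketch below. The function name extract_with_loop and the small stand-in arrays are assumptions made for illustration; the question's own sizes are M, N, P = 10, 5, 2000 with 50 points per window.

```python
import numpy as np

def extract_with_loop(array1, array2, K):
    """Reference implementation: slice K points per (m, n) location."""
    M, N, _ = array1.shape
    out = np.empty((M, N, K), dtype=array1.dtype)
    for m in range(M):
        for n in range(N):
            start = array2[m, n, 0]
            out[m, n] = array1[m, n, start:start + K]
    return out

# Small stand-in data so the result is easy to inspect
array1 = np.arange(2 * 3 * 10).reshape(2, 3, 10)
array2 = np.array([[[0], [2], [5]], [[1], [3], [4]]])  # shape (2, 3, 1)
result = extract_with_loop(array1, array2, K=4)
print(result.shape)  # (2, 3, 4)
```

A loop like this is a handy baseline for checking any vectorized version on small inputs.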

Because the start index differs at each location, a plain slice cannot express the selection. Instead, you can build a boolean mask or a fancy index that picks out the desired elements. With a boolean mask the extracted values come back flat, so you then have to reshape them.
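The fancy-index route is mentioned but not spelled out in this answer. One way it might look, as a sketch rather than the answerer's own code, is to broadcast the start indices against np.arange(K) and pass the result to np.take_along_axis. The small sizes and array contents below are assumptions chosen for illustration.

```python
import numpy as np

M, N, P, K = 2, 3, 10, 4
array1 = np.arange(M * N * P).reshape(M, N, P)
array2 = np.array([[[0], [2], [5]], [[1], [3], [6]]])  # shape (M, N, 1)

# Broadcasting (M, N, 1) + (K,) gives an (M, N, K) array of positions
idx = array2 + np.arange(K)
array3 = np.take_along_axis(array1, idx, axis=-1)
print(array3.shape)  # (2, 3, 4)
```

Here the result already has shape (M, N, K), so no reshape is needed. The index array holds M * N * K integers, which is usually smaller than a full (M, N, P) mask when K is much less than P.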

Here is how you can create a mask:

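import numpy as np

K = 50  # number of points to extract at each (m, n) location
# One extra slot along the last axis keeps array2 + K == P in bounds
mask = np.zeros((M, N, P + 1), dtype=np.int8)
np.put_along_axis(mask, array2, 1, axis=-1)       # +1 where each window starts
np.put_along_axis(mask, array2 + K, -1, axis=-1)  # -1 just past each window's end
mask.cumsum(axis=-1, out=mask)                    # running sum: 1 inside a window, 0 outside
mask = mask[..., :-1].view(bool)                  # drop the extra slot, reinterpret as bool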

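This relies on two facts: np.int8 and np.bool_ have the same itemsize, so the array can be reinterpreted as boolean with a view instead of a copy, and np.cumsum carries the 1 written at each start position forward until the -1 written at the end position cancels it, leaving 1 inside each window and 0 elsewhere.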
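The cumulative-sum trick may be easier to follow on a single row. The sketch below uses made-up values (P = 8, a window of 3 starting at index 2) and passes dtype=np.int8 explicitly so the running sum stays in the same type as the output buffer.

```python
import numpy as np

P, K, start = 8, 3, 2
row = np.zeros(P + 1, dtype=np.int8)
row[start] = 1       # window opens here
row[start + K] = -1  # window closes here
np.cumsum(row, dtype=np.int8, out=row)  # row is now [0, 0, 1, 1, 1, 0, 0, 0, 0]
row_mask = row[:-1].view(bool)          # same bytes, read as booleans
print(row_mask)
```

Only positions 2, 3 and 4 end up True, which is exactly the window start:start + K.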

The rest is fairly easy:

array3 = array1[mask].reshape(M, N, K)
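To check the whole recipe end to end, the sketch below runs it on small stand-in data and compares the result with plain per-location slicing. The sizes and array contents are assumptions for illustration, and dtype=np.int8 is passed to np.cumsum explicitly as a precaution.

```python
import numpy as np

M, N, P, K = 2, 3, 10, 4
array1 = np.arange(M * N * P).reshape(M, N, P)
array2 = np.array([[[0], [2], [5]], [[1], [3], [6]]])  # (M, N, 1); 6 + K == P is allowed

mask = np.zeros((M, N, P + 1), dtype=np.int8)
np.put_along_axis(mask, array2, 1, axis=-1)
np.put_along_axis(mask, array2 + K, -1, axis=-1)
np.cumsum(mask, axis=-1, dtype=np.int8, out=mask)
mask = mask[..., :-1].view(bool)

array3 = array1[mask].reshape(M, N, K)

# Compare against plain slicing, one (m, n) location at a time
expected = np.stack([
    np.stack([array1[m, n, array2[m, n, 0]:array2[m, n, 0] + K] for n in range(N)])
    for m in range(M)
])
assert np.array_equal(array3, expected)
```

The reshape works because boolean indexing returns the selected elements in C order and every (m, n) location contributes exactly K of them.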

You can avoid the extra element when constructing the mask by bypassing np.put_along_axis and using direct indexing with clipping where appropriate:

mask = np.zeros_like(array1, dtype=np.int8)
# Row and column indices expanded to shape (M, N, 1) to match array2
r = np.tile(np.arange(M)[:, None, None], (1, N, 1))
c = np.tile(np.arange(N)[None, :, None], (M, 1, 1))
# Windows that end exactly at P need no closing marker
clip_mask = array2 + K < P
mask[r, c, array2] = 1
mask[r[clip_mask], c[clip_mask], array2[clip_mask] + K] = -1
mask = np.cumsum(mask, axis=-1, out=mask).view(bool)
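The clipping matters when a window runs right up to the end of the last axis. The sketch below exercises that case on small made-up data, where one start index satisfies start + K == P. It follows the same steps as the snippet above, with dtype=np.int8 passed to np.cumsum explicitly as a precaution.

```python
import numpy as np

M, N, P, K = 2, 2, 8, 3
array1 = np.arange(M * N * P).reshape(M, N, P)
array2 = np.array([[[0], [5]], [[2], [4]]])  # 5 + K == P: this window runs to the end

mask = np.zeros_like(array1, dtype=np.int8)
r = np.tile(np.arange(M)[:, None, None], (1, N, 1))
c = np.tile(np.arange(N)[None, :, None], (M, 1, 1))
clip_mask = array2 + K < P  # False for the window that ends at P
mask[r, c, array2] = 1
mask[r[clip_mask], c[clip_mask], array2[clip_mask] + K] = -1
mask = np.cumsum(mask, axis=-1, dtype=np.int8, out=mask).view(bool)

array3 = array1[mask].reshape(M, N, K)
print(array3[0, 1])  # the window that ends at P
```

For the window starting at 5 no closing -1 is written, so the running sum simply stays at 1 through the last element.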

You can build a mask by comparing the values in array2 with an index range over the last dimension. For example:

import numpy as np
    
M,N,P,k = 4,2,15,3   # yours would be 10,5,2000,50

A1 = np.arange(M*N*P).reshape((M,N,P))
A2 = np.arange(M*N).reshape((M,N,1)) + 1

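rP = np.arange(P)[None,None,:]   # index range along the last axis, shape (1, 1, P)
# True where start <= index < start + k; each (m, n) row has exactly k True values as long as A2 + k <= P
A3 = A1[(rP>=A2)&(rP<A2+k)].reshape((M,N,k))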

Input:

print(A1)

[[[  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14]
  [ 15  16  17  18  19  20  21  22  23  24  25  26  27  28  29]]

 [[ 30  31  32  33  34  35  36  37  38  39  40  41  42  43  44]
  [ 45  46  47  48  49  50  51  52  53  54  55  56  57  58  59]]

 [[ 60  61  62  63  64  65  66  67  68  69  70  71  72  73  74]
  [ 75  76  77  78  79  80  81  82  83  84  85  86  87  88  89]]

 [[ 90  91  92  93  94  95  96  97  98  99 100 101 102 103 104]
  [105 106 107 108 109 110 111 112 113 114 115 116 117 118 119]]]

print(A2)

[[[1]
  [2]]

 [[3]
  [4]]

 [[5]
  [6]]

 [[7]
  [8]]]

Output:

print(A3)

[[[  1   2   3]
  [ 17  18  19]]

 [[ 33  34  35]
  [ 49  50  51]]

 [[ 65  66  67]
  [ 81  82  83]]

 [[ 97  98  99]
  [113 114 115]]]
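If a window ran past the end of the last axis, the mask would select fewer than k elements for that location and the reshape would fail with a generic size error. One way to guard against that is to wrap the comparison in a small helper that checks the bounds first. The name extract_windows and the choice to raise ValueError are assumptions for this sketch, not part of the answer.

```python
import numpy as np

def extract_windows(a, starts, k):
    """Hypothetical helper: take k items along the last axis from each start.

    a has shape (..., P) and starts has shape (..., 1).
    """
    P = a.shape[-1]
    if starts.min() < 0 or starts.max() + k > P:
        raise ValueError("every window must lie within the last axis")
    rP = np.arange(P)  # broadcasts against starts to shape (..., P)
    mask = (rP >= starts) & (rP < starts + k)
    return a[mask].reshape(starts.shape[:-1] + (k,))

# Same data as the example above
M, N, P, k = 4, 2, 15, 3
A1 = np.arange(M * N * P).reshape((M, N, P))
A2 = np.arange(M * N).reshape((M, N, 1)) + 1
A3 = extract_windows(A1, A2, k)
print(A3[0])
```

Keep in mind that the comparison builds a full (M, N, P) boolean array, which for the question's sizes (10, 5, 2000) is still only about 100,000 elements.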

