
Remove rows from numpy array based on presence/absence in other arrays

I have 3 different numpy arrays, but they all start with two columns which contain the day of year and the time. For example:

   dyn = [[  83   12   7.10555687e-01 ...,   6.99242766e-01   6.868761e-01]
         [  83   13   8.28091972e-01 ...,   8.33734118e-01   8.47266838e-01]
         [  83   14   8.79437354e-01 ...,   8.73598144e-01   8.57156213e-01]
         [  161   23   3.28109488e-01 ...,   2.83043689e-01  2.59775391e-01]
         [  162   0    2.23502046e-01 ...,   1.96972086e-01  1.65565263e-01]
         [  162   1   2.51653976e-01 ...,   2.17209188e-01   1.42133495e-1]]

   us = [[  133   18   3.00483815e+02 ...,   1.94277561e+00   2.8168959e+00]
        [  133   19   2.98832620e+02 ...,   2.42506475e+00   2.99730800e+00]
        [  133   20   2.96706105e+02 ...,   3.16851622e+00   4.41187088e+00]
        [  161   23   2.88336560e+02 ...,   3.44864070e-01   3.85055635e-01]
        [  162   0    2.87593240e+02 ...,   2.93002410e-01   2.67112490e-01]
        [  162   2    2.86992180e+02 ...,   7.08996730e-02   2.6403210e-01]]

I need to remove any rows whose specific day and time aren't present in all 3 arrays. In other words, I want to be left with 3 arrays whose first 2 columns are identical.

So the resulting smaller arrays would be:

dyn= [[  161   23   3.28109488e-01 ...,   2.83043689e-01  2.59775391e-01]
     [  162   0    2.23502046e-01 ...,   1.96972086e-01  1.65565263e-01]]

us= [[  161   23   2.88336560e+02 ...,   3.44864070e-01   3.85055635e-01]
    [  162   0    2.87593240e+02 ...,   2.93002410e-01   2.67112490e-01]]

(But then also limited by what's in the third array)

I've tried using sorted/zip, but I'm not sure it can be applied to a 2D array like this:

X= dyn
Y = us
xsorted=[x for (y,x) in sorted(zip(Y[:,1],X[:,1]), key=lambda pair: pair[0])]

I also tried a loop, but that only works when the same days/times are in the same position within each array, which isn't helpful:

for i in range(100):
     dyn_small=dyn[dyn[:,0]==us[i,0]]

Assuming A, B and C are the input arrays, here's a vectorized approach that makes heavy use of broadcasting:

import numpy as np

# Get masks comparing the first two columns of all rows of A with B, and of B with C
M1 = (A[:,None,:2] == B[:,:2])
M2 = (B[:,None,:2] == C[:,:2])

# Combine the two masks into a joint mask; .all(3) requires both columns to
# match, leaving a 3D mask. Its indices (I,J,K) tell us the row numbers
# in each of the input arrays whose (day, time) pair is present in all of them.
# Thus, in (I,J,K), I is the matching row number in A, J in B and K in C.
I,J,K = np.where((M1[:,:,None,:] & M2).all(3))

# Finally, select rows of A, B and C with I, J and K respectively
A_new = A[I]
B_new = B[J]
C_new = C[K]
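To see why this indexing produces a 3D mask, the sketch below tracks only array shapes, using dummy zero arrays sized like the sample run in this answer (6, 5 and 3 rows). The dummy arrays are placeholders, not real data. It also makes the memory cost visible: the joint mask holds nA × nB × nC × 2 booleans before `.all(3)` reduces it.

```python
import numpy as np

# Dummy arrays sized like the sample run (6, 5 and 3 rows, 3 columns);
# only their shapes matter here, not their values.
A = np.zeros((6, 3), dtype=int)
B = np.zeros((5, 3), dtype=int)
C = np.zeros((3, 3), dtype=int)

M1 = A[:, None, :2] == B[:, :2]   # (6, 1, 2) vs (5, 2)       -> (6, 5, 2)
M2 = B[:, None, :2] == C[:, :2]   # (5, 1, 2) vs (3, 2)       -> (5, 3, 2)
joint = M1[:, :, None, :] & M2    # (6, 5, 1, 2) vs (5, 3, 2) -> (6, 5, 3, 2)
match = joint.all(3)              # both columns equal        -> (6, 5, 3)

print(M1.shape, M2.shape, joint.shape, match.shape)
```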

Sample run:

1) Inputs:

In [116]: A
Out[116]: 
array([[ 83,  12, 443],
       [ 83,  13, 565],
       [ 83,  14, 342],
       [161,  23, 431],
       [162,   0, 113],
       [162,   1, 313]])

In [117]: B
Out[117]: 
array([[161,  23, 999],
       [  5,   1,  13],
       [ 83,  12,  15],
       [162,   0,  12],
       [  4,   3,  11]])

In [118]: C
Out[118]: 
array([[ 11,  23, 143],
       [162,   0, 113],
       [161,  23, 545]])

2) Run the solution code to get the matching row indices and extract the rows:

In [119]: M1 = (A[:,None,:2] == B[:,:2])
     ...: M2 = (B[:,None,:2] == C[:,:2])
     ...: 

In [120]: I,J,K = np.where((M1[:,:,None,:] & M2).all(3))

In [121]: A[I]
Out[121]: 
array([[161,  23, 431],
       [162,   0, 113]])

In [122]: B[J]
Out[122]: 
array([[161,  23, 999],
       [162,   0,  12]])

In [123]: C[K]
Out[123]: 
array([[161,  23, 545],
       [162,   0, 113]])
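One behaviour of this approach worth keeping in mind: `np.where` returns every matching (i, j, k) triple, so if a day/time pair occurs more than once in any of the arrays, the corresponding rows are repeated in the output. A minimal sketch, using a hypothetical B that contains (161, 23) twice:

```python
import numpy as np

A = np.array([[161, 23, 431],
              [162,  0, 113]])
# Hypothetical B in which the (161, 23) pair appears twice
B = np.array([[161, 23, 999],
              [161, 23, 998],
              [162,  0,  12]])
C = np.array([[162,  0, 113],
              [161, 23, 545]])

M1 = A[:, None, :2] == B[:, :2]
M2 = B[:, None, :2] == C[:, :2]
I, J, K = np.where((M1[:, :, None, :] & M2).all(3))

# The (161, 23) row of A is listed twice, once per matching row of B
print(A[I])
```

If each day/time pair occurs at most once per array (likely for hourly data, but an assumption here), this does not arise.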

The numpy_indexed package (disclaimer: I am its author) contains functionality to solve such problems in an elegant and efficient/vectorized manner:

import numpy as np
import numpy_indexed as npi

dyn = np.array(dyn)
us = np.array(us)

dyn_index = npi.as_index(dyn[:, :2])
us_index = npi.as_index(us[:, :2])

common = npi.intersection(dyn_index, us_index)
print(common)
print(dyn[npi.contains(common, dyn_index)])
print(us[npi.contains(common, us_index)])

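Note that the performance is O(N log N) in the worst case, and linear if the arguments to as_index are already in sorted order. By contrast, the currently accepted answer compares every row of one array against every row of another, so its cost grows with the product of the input sizes (quadratic for two arrays).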
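The O(N log N) behaviour comes from sorting. For readers who would rather not add a dependency, a similar sort-based idea can be sketched with plain NumPy. This is a swapped-in technique, not numpy_indexed's implementation: pack (day, hour) into one integer key, intersect the keys of all arrays with `np.intersect1d`, then filter each array with `np.isin`. `keep_common_rows` is a hypothetical helper name, and the key assumes the time column holds integer hours in 0..23. It accepts any number of arrays, so the third array from the question can be passed too.

```python
from functools import reduce

import numpy as np


def keep_common_rows(*arrays):
    """Keep only rows whose first two columns (day, hour) occur in every array.

    Assumption: columns 0 and 1 hold integer-valued day and hour (hour in
    0..23), so day * 24 + hour is a unique integer key per (day, hour) pair.
    """
    keys = [a[:, 0].astype(np.int64) * 24 + a[:, 1].astype(np.int64)
            for a in arrays]
    # np.intersect1d works by sorting, so this step is O(N log N)
    common = reduce(np.intersect1d, keys)
    out = []
    for a, k in zip(arrays, keys):
        mask = np.isin(k, common)                        # rows whose key is common
        order = np.argsort(k[mask], kind="stable")       # align rows across arrays
        out.append(a[mask][order])
    return out


# Usage with the sample inputs from the broadcasting answer
A = np.array([[83, 12, 443], [83, 13, 565], [83, 14, 342],
              [161, 23, 431], [162, 0, 113], [162, 1, 313]])
B = np.array([[161, 23, 999], [5, 1, 13], [83, 12, 15],
              [162, 0, 12], [4, 3, 11]])
C = np.array([[11, 23, 143], [162, 0, 113], [161, 23, 545]])

A_new, B_new, C_new = keep_common_rows(A, B, C)
print(A_new, B_new, C_new, sep="\n")
```

Sorting each result by key keeps the rows aligned across arrays, assuming each (day, hour) pair is unique within an array.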
