
Apply argsort per row in array skipping certain elements based on threshold - NumPy / Python

I would like to apply a sort operation row by row, keeping only the values above a given threshold.

For this, I see I can use a masked array to apply the threshold. However, argsort still takes the masked values (those below the threshold) into account and replaces them with a fill_value.

What I want instead is no result at all (NaN) for any value that has been masked.

import numpy as np

a = np.array([[0.522235,0.128270,0.708973],
              [0.994557,0.844426,0.366608],
              [0.986669,0.143659,0.395891],
              [0.291339,0.421843,0.278869],
              [0.250303,0.861475,0.904534],
              [0.973436,0.360466,0.751913]])

threshold = 0.5
m_a = np.ma.masked_less_equal(a, threshold)
argsorted = m_a.argsort(-1)

This gives me:

array([[0, 2, 1],
       [1, 0, 2],
       [0, 1, 2],
       [0, 1, 2],
       [1, 2, 0],
       [2, 0, 1]])

But I would like to get:

array([[0,   NaN,   1],
       [1,     0, NaN],
       [0,   NaN, NaN],
       [NaN, NaN, NaN],
       [NaN,   0,   1],
       [  1, NaN,   0]])

Any idea how to get this result?

Thanks for your help. Best regards,

Applying one more argsort turns the sort order into per-row ranks, which gives an easy way to get the desired output -

sidx = argsorted.argsort(1)
mask = sidx >= (a.shape[1]-m_a.mask.sum(1,keepdims=True))
out = np.where(mask,np.nan,sidx)
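
To see why the second argsort helps: argsort returns the positions that would sort a row, and argsorting that result converts those positions into the rank of each element. A minimal sketch on a single made-up row (the values and the names `order` and `ranks` are illustrative assumptions, not from the question):

```python
import numpy as np

row = np.array([[0.7, 0.9, 0.2]])   # illustrative values only

order = row.argsort(1)    # positions that would sort the row: [[2, 0, 1]]
ranks = order.argsort(1)  # rank of each element in its row:   [[1, 2, 0]]

print(order)
print(ranks)
```

Since masked values are sorted to the end of each row, they receive the highest ranks, which is what the `sidx >= ...` comparison above relies on.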

We can also start from scratch to avoid masked-arrays -

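def thresholded_argsort(a, threshold):
    m = a < threshold                  # True where the value is dropped
    ac = a.copy()
    ac[m] = ac.max() + 1               # push dropped values to the end of the sort
    sidx = ac.argsort(1).argsort(1)    # rank of each element within its row
    # ranks at or beyond the per-row count of kept values belong to dropped values
    mask = sidx >= (ac.shape[1] - m.sum(1, keepdims=True))
    return np.where(mask, np.nan, sidx)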

Sample run -

In [46]: a
Out[46]: 
array([[0.522235, 0.12827 , 0.708973],
       [0.994557, 0.844426, 0.366608],
       [0.986669, 0.143659, 0.395891],
       [0.291339, 0.421843, 0.278869],
       [0.250303, 0.861475, 0.904534],
       [0.973436, 0.360466, 0.751913]])

In [47]: thresholded_argsort(a, threshold=0.5)
Out[47]: 
array([[ 0., nan,  1.],
       [ 1.,  0., nan],
       [ 0., nan, nan],
       [nan, nan, nan],
       [nan,  0.,  1.],
       [ 1., nan,  0.]])

Note: For performance, we can replace the additional argsort with an array assignment, using argsort_unique. For 2D arrays along the second axis, it would be -

def argsort_unique2D(idx):
    m,n = idx.shape
    idx_out = np.empty((m,n),dtype=int)
    np.put_along_axis(idx_out, idx, np.arange(n), axis=1)
    return idx_out

So, in the solutions posted earlier, argsorted.argsort(1) could be replaced by argsort_unique2D(argsorted), and ac.argsort(1).argsort(1) by argsort_unique2D(ac.argsort(1)).
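
Putting the two pieces together might look like the sketch below. The name `thresholded_argsort_fast` is a hypothetical one chosen for illustration; the body only combines the functions already shown above.

```python
import numpy as np

def argsort_unique2D(idx):
    # Invert a per-row permutation by scattering 0..n-1 into place.
    m, n = idx.shape
    idx_out = np.empty((m, n), dtype=int)
    np.put_along_axis(idx_out, idx, np.arange(n), axis=1)
    return idx_out

def thresholded_argsort_fast(a, threshold):  # hypothetical name
    m = a < threshold
    ac = a.copy()
    ac[m] = ac.max() + 1                     # dropped values sort last
    sidx = argsort_unique2D(ac.argsort(1))   # ranks, with a single argsort
    mask = sidx >= (ac.shape[1] - m.sum(1, keepdims=True))
    return np.where(mask, np.nan, sidx)

a = np.array([[0.522235, 0.128270, 0.708973],
              [0.994557, 0.844426, 0.366608],
              [0.986669, 0.143659, 0.395891],
              [0.291339, 0.421843, 0.278869],
              [0.250303, 0.861475, 0.904534],
              [0.973436, 0.360466, 0.751913]])

out = thresholded_argsort_fast(a, 0.5)
print(out)
```

On the sample array this should reproduce the output of thresholded_argsort shown in the sample run.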

If I understand correctly, you don't want to consider NaN for the sorting. In that case, I am not sure about the logic behind your expected result. You can try the following code. I believe this is what you are looking for:-

import numpy as np
a = np.array([[0.522235,0.128270,0.708973],
              [0.994557,0.844426,0.366608],
              [0.986669,0.143659,0.395891],
              [0.291339,0.421843,0.278869],
              [0.250303,0.861475,0.904534],
              [0.973436,0.360466,0.751913]])

threshold = 0.5
m_a = np.ma.masked_less_equal(a, threshold).filled(np.nan)
result = np.where(
        np.isnan(m_a),
        np.nan, m_a.argsort(-1)
    )
result

It should give you the following result:-

array([[ 0., nan,  1.],
       [ 1.,  0., nan],
       [ 0., nan, nan],
       [nan, nan, nan],
       [nan,  2.,  0.],
       [ 2., nan,  1.]])
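
The last two rows of this output differ from the output requested in the question. A likely reason is that argsort returns the positions that would sort each row, whereas the requested output holds each kept element's rank. A small sketch on the fifth row after thresholding (the names `positions` and `ranks` are illustrative assumptions):

```python
import numpy as np

# Fifth row of `a`, with the value below the threshold replaced by NaN.
row = np.array([np.nan, 0.861475, 0.904534])

positions = row.argsort()      # NaN sorts last: [1, 2, 0]
ranks = positions.argsort()    # rank of each element: [2, 0, 1]

print(positions)
print(ranks)
```

Reading `ranks` at the two kept entries gives 0 and 1, as in the requested output, while `positions` at the same entries gives 2 and 0, as in the result above.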

Hope this helps!!

import numpy as np

a = np.array([[0.522235,0.128270,0.708973],
              [0.994557,0.844426,0.366608],
              [0.986669,0.143659,0.395891],
              [0.291339,0.421843,0.278869],
              [0.250303,0.861475,0.904534],
              [0.973436,0.360466,0.751913]])

threshold = .5


def tri(ligne):
    s = sorted(ligne, key=lambda x: x < threshold and float('inf') or x)
    nv_liste = [s.index(v) for v in ligne]
    for i in range(len(ligne)):
        if ligne[i] < threshold:
            nv_liste[i] = np.nan
    return nv_liste

np.apply_along_axis(tri, 1, a)

gives you:

array([[ 0., nan,  1.],
       [ 1.,  0., nan],
       [ 0., nan, nan],
       [nan, nan, nan],
       [nan,  0.,  1.],
       [ 1., nan,  0.]])
