Numpy argsort while distinguishing values of 0

I have a very large array but here I will show a simplified case:

a = np.array([[3, 0, 5, 0], [8, 7, 6, 10], [5, 4, 0, 10]])
array([[ 3,  0,  5,  0],
       [ 8,  7,  6, 10],
       [ 5,  4,  0, 10]])

I want to argsort() the array but have a way to distinguish the 0s. I tried replacing them with NaN:

a = np.array([[3, np.nan, 5, np.nan], [8, 7, 6, 10], [5, 4, np.nan, 10]])
a.argsort()
array([[0, 2, 1, 3],
       [2, 1, 0, 3],
       [1, 0, 3, 2]])

But the NaNs are still being sorted and still receive an ordinary index. Is there any way to make argsort give them a value of -1 or something similar? Or is there another option besides NaN for replacing the 0s? I tried math.inf as well, with no success. Does anybody have any ideas?

The purpose of doing this is that I have a cosine similarity matrix, and I want to exclude those instances where the similarity is 0. I am using argsort() to get the highest similarities, which gives me the indices into another table with mappings to labels. If every similarity in a row is 0 ([0,0,0]), then I want to ignore that row. So if I can get argsort() to output it as [-1,-1,-1] after sorting, I can check whether the entire row is -1 and exclude it.

EDIT:

So the output should be:

array([[0, 2, -1, -1],
       [2, 1, 0, 3],
       [1, 0, 3, -1]])

So, reading the last row of the output against the last row of a: the smallest value is a[2][1], which is 4, followed by a[2][0], which is 5, then a[2][3], which is 10, and finally -1, which stands for the 0.
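
The row check described above could look roughly like this. It is only a sketch: idx is a hypothetical argsort result that already uses -1 for the zero entries, and the names keep, kept_rows and filtered are made up for illustration.

```python
import numpy as np

# Hypothetical argsort result in which -1 marks entries whose similarity was 0.
idx = np.array([[ 0,  2, -1, -1],
                [-1, -1, -1, -1],   # every similarity in this row was 0
                [ 1,  0,  3, -1]])

# Keep a row only if at least one entry is a real index.
keep = ~(idx == -1).all(axis=1)
kept_rows = np.flatnonzero(keep)    # original row numbers that survive
filtered = idx[keep]
```

kept_rows makes it possible to map the filtered rows back to the original table.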

If by "distinguish 0s" you mean treating them as the highest or the lowest value, I would suggest trying:

a[a==0]=(a.max()+1)

or:

a[a==0]=(a.min()-1)
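
A minimal sketch of the first variant on the array from the question, assuming it is acceptable for the zeros to still receive ordinary indices (they are only pushed to the end of each row, not turned into -1) and that a.max() + 1 does not overflow the dtype. The names work and order are chosen here for illustration.

```python
import numpy as np

a = np.array([[3, 0, 5, 0], [8, 7, 6, 10], [5, 4, 0, 10]])

work = a.copy()                    # leave the original array untouched
work[work == 0] = work.max() + 1   # zeros now compare greater than every real value
order = work.argsort(axis=1, kind="stable")
```

Working on a copy keeps the zeros in a available for later checks such as a == 0.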

You may want to use numpy.ma.array(), like this:

a = np.array([[3,4,5],[8,7,6],[5,4,0]])

Mask this array with the condition a == 0:

a_mask = np.ma.array(a, mask=(a==0))
print(repr(a_mask))
# output
masked_array(
  data=[[3, 4, 5],
        [8, 7, 6],
        [5, 4, --]],
  mask=[[False, False, False],
        [False, False, False],
        [False, False,  True]],
  fill_value=999999)
print(repr(a_mask.mask))
# output
array([[False, False, False],
       [False, False, False],
       [False, False,  True]])

You can then use the mask attribute of the masked array to identify the elements you want to label and fill in other values for them.
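
As a rough sketch of how the masked array could be taken further to get the -1 layout asked for in the question: MaskedArray.argsort places masked entries last when endwith=True, and count() gives the number of unmasked entries per row. The names order, num_valid and cols are chosen here for illustration.

```python
import numpy as np

a = np.array([[3, 0, 5, 0], [8, 7, 6, 10], [5, 4, 0, 10]])
a_mask = np.ma.array(a, mask=(a == 0))

# Masked entries are treated as largest, so their indices end up last in each row.
order = np.asarray(a_mask.argsort(axis=1, endwith=True))

# Number of unmasked (non-zero) entries per row.
num_valid = np.asarray(a_mask.count(axis=1))

# Overwrite the trailing slots, which hold the masked entries, with -1.
cols = np.arange(a.shape[1])
order[cols >= num_valid[:, None]] = -1
```

The axis is passed explicitly because the default axis of MaskedArray.argsort for multi-dimensional input is not the last axis.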

One way to achieve this is to first build a boolean mask that marks the zero values (since those are what you want to distinguish), then argsort the array, and finally use the boolean mask to set the desired value (e.g., -1).

# your unmodified input array
In [294]: a
Out[294]: 
array([[3, 4, 5],
       [8, 7, 6],
       [5, 4, 0]])

# boolean mask checking for zero
In [295]: zero_bool_mask = a == 0

In [296]: zero_bool_mask
Out[296]: 
array([[False, False, False],
       [False, False, False],
       [False, False,  True]])

# usual argsort
In [297]: sorted_idxs = np.argsort(a)

In [298]: sorted_idxs
Out[298]: 
array([[0, 1, 2],
       [2, 1, 0],
       [2, 1, 0]])

# replace the indices of 0 with desired value (e.g., -1)
In [299]: sorted_idxs[zero_bool_mask] = -1

In [300]: sorted_idxs
Out[300]: 
array([[ 0,  1,  2],
       [ 2,  1,  0],
       [ 2,  1, -1]])

Following this, to adjust the remaining sorting indices for the substituted value (e.g., -1), we shift every entry in a row down by the number of replaced entries in that row:

In [327]: sorted_idxs - (sorted_idxs == -1).sum(1)[:, None]
Out[327]: 
array([[ 0,  1,  2],
       [ 2,  1,  0],
       [ 1,  0, -2]])

So now the negative values in sorted_idxs sit at the column positions where the original array had zeros.


Thus, we can have a custom function like so:

def argsort_excluding_zeros(arr, replacement_value):
    zero_bool_mask = arr == 0
    sorted_idxs = np.argsort(arr)
    sorted_idxs[zero_bool_mask] = replacement_value
    return sorted_idxs - (sorted_idxs == replacement_value).sum(1)[:, None]

# another array
In [339]: a
Out[339]: 
array([[0, 4, 5],
       [8, 7, 6],
       [5, 4, 0]])

# sample run
In [340]: argsort_excluding_zeros(a, replacement_value=-1)
Out[340]: 
array([[-2,  0,  1],
       [ 2,  1,  0],
       [ 1,  0, -2]])

Using @kmario23's and @ScienceSnake's code, I came up with this solution:

a = np.array([[3, 0, 5, 0], [8, 7, 6, 10], [5, 4, 0, 10]])
b = np.where(a == 0, np.inf, a)  # replace 0 -> inf so the zeros are sorted last
s = b.copy()  # copy b so it can be sorted in place
s.sort()
mask = s == np.inf  # True where inf ends up after sorting
c = b.argsort()
d = np.where(mask, -1, c)  # replace the indices of the former zeros with -1

Out:
array([[ 0,  2, -1, -1],
       [ 2,  1,  0,  3],
       [ 1,  0,  3, -1]])

This is not the most efficient solution, because it sorts twice.
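
One possible way to avoid the second sort, sketched under the assumption that the array contains no genuine inf values: reorder b with the argsort result via np.take_along_axis instead of sorting a copy.

```python
import numpy as np

a = np.array([[3, 0, 5, 0], [8, 7, 6, 10], [5, 4, 0, 10]])
b = np.where(a == 0, np.inf, a)   # zeros -> inf so they sort last
c = b.argsort(axis=1)             # the only sort

# Reorder b by c instead of sorting it again, then find where inf landed.
mask = np.take_along_axis(b, c, axis=1) == np.inf
d = np.where(mask, -1, c)
```

Here d matches the array shown above, and gathering along the rows is cheaper than a second full sort.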

There might be a slightly more efficient alternative, but this works in pure numpy and is very transparent.

import numpy as np

a = np.array([[3, 0, 5, 0], [8, 7, 6, 10], [5, 4, 0, 10]])
b = np.where(a == 0, np.inf, a)  # Replace 0 -> inf to make them sorted last
c = b.argsort()
d = np.where(a == 0, -1, c)  # Replace where the zeros were originally with -1
print(d)

outputs

[[ 0 -1  1 -1]
 [ 2  1  0  3]
 [ 1  0 -1  2]]

To save memory, some of the in-between assignments can be skipped, but I left it this way for clarity.

*** EDIT ***

The OP has clarified exactly what output they want. This is my new solution which has only one sort.

a = np.array([[3, 0, 5, 0], [8, 7, 6, 10], [5, 4, 0, 10]])
b = np.where(a == 0, np.inf, a).argsort()


def remove_invalid_entries(row, num_valid):
    row[num_valid.pop():] = -1
    return row


num_valid = np.flip(np.count_nonzero(a, 1)).tolist()
b = np.apply_along_axis(remove_invalid_entries, 1, b, num_valid)

print(b)
> [[ 0  2 -1 -1]
   [ 2  1  0  3]
   [ 1  0  3 -1]]

The start is as before. Then we go through the argsorted array row by row and replace the last n elements with -1, where n is the number of 0s in the corresponding row of the original array. One way of doing this is with np.apply_along_axis, although it loops over the rows in Python rather than in compiled code. Here, I count the non-zero entries in each row of a and turn the counts into a list in reversed order, so that pop() returns the number of elements to keep in the row of b currently being processed by np.apply_along_axis.
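
The same row-wise replacement can probably also be written without a per-row callback, by comparing column positions against the per-row count of non-zero entries. This is a sketch of an alternative, not the code used above; cols is a name introduced here for illustration.

```python
import numpy as np

a = np.array([[3, 0, 5, 0], [8, 7, 6, 10], [5, 4, 0, 10]])
b = np.where(a == 0, np.inf, a).argsort(axis=1)

num_valid = np.count_nonzero(a, axis=1)   # non-zero entries per row
cols = np.arange(a.shape[1])              # column positions 0..n-1

# In each row, every slot at or beyond num_valid holds the index of a zero.
b[cols >= num_valid[:, None]] = -1
```

Broadcasting cols against num_valid[:, None] builds the whole boolean mask in one step, so no list and no pop() are needed.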