
Get index for argmin of 2d numpy array

I have a 2D numpy array of distances:

import numpy as np

a = np.array([[2.0, 12.1, 99.2],
              [1.0, 1.1, 1.2],
              [1.04, 1.05, 1.5],
              [4.1, 4.2, 0.2],
              [10.0, 11.0, 12.0],
              [3.9, 4.9, 4.99]
             ])

I need a function that examines each row and returns the index of the column holding that row's smallest value. Of course, this can be done trivially with:

np.argmin(a, axis=1) 

which yields:

[0, 0, 0, 2, 0, 0]

However, I have a few constraints:

  1. The argmin evaluation should only consider distances below 5.0. If none of the distances in a row is below 5.0, return -1 as the index.
  2. The indices returned for all rows must be unique, i.e., if two or more rows end up with the same column index, the row with the smallest distance to that column gets priority and all other rows must return a different column index. I'm guessing this makes the problem iterative, since a row that gets bumped could subsequently clash with another row over its next-best column index.
  3. Any rows left unassigned should return -1.

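Constraint 1 on its own is easy to vectorize. A minimal sketch, assuming a strict cutoff of 5.0 (the helper name `masked_argmin` and its `limit` parameter are hypothetical): replace out-of-range distances with `np.inf` so `argmin` skips them, then flag rows that had no valid distance with -1.

```python
import numpy as np

def masked_argmin(a, limit=5.0):
    """Row-wise argmin that ignores distances >= limit (hypothetical helper)."""
    valid = a < limit                    # True where a distance may be used
    masked = np.where(valid, a, np.inf)  # push invalid distances out of contention
    idx = np.argmin(masked, axis=1)      # an all-inf row yields 0 here...
    idx[~valid.any(axis=1)] = -1         # ...so flag rows with no valid distance
    return idx

a = np.array([[2.0, 12.1, 99.2],
              [1.0, 1.1, 1.2],
              [1.04, 1.05, 1.5],
              [4.1, 4.2, 0.2],
              [10.0, 11.0, 12.0],
              [3.9, 4.9, 4.99]])
print(masked_argmin(a).tolist())  # [0, 0, 0, 2, -1, 0]
```

This alone does not enforce uniqueness (constraint 2): rows 0, 1, 2 and 5 still all point at column 0.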
Thus, the final output should look like:

[-1, 0, 1, 2, -1, -1]

One starting point would be to:

  1. perform an argsort
  2. assign unique column indices to rows
  3. remove assigned column indices from each row
  4. resolve ties between rows competing for the same column
  5. repeat steps 2-4 until either all column indices are assigned or no unassigned row has a remaining candidate below 5.0

Is there a straightforward way to accomplish this in Python?

This loops over the unique column indices returned by argmin, so the number of iterations is bounded by the number of columns, which I assume is smaller than the number of rows:

import numpy as np

def find_smallest(a):
    i = np.argmin(a, 1)  # column of each row's minimum
    amin = a[np.arange(len(a)), i]  # the row minima, reusing i (faster than a.min(1)?)
    toobig = amin >= 5
    i[toobig] = -1  # no distance below 5.0 in this row
    for u, c in zip(*np.unique(i, return_counts=True)):
        # u, c are the unique values and number of occurrences in `i`
        if u < 0 or c < 2:
            # -1 is not a real column, or this index has no repeats
            continue
        mask = i == u  # the entries of i that match u, which has repeats
        notclosest = np.where(mask)[0].tolist()  # row indices of the repeats
        notclosest.pop(np.argmin(amin[mask]))  # the row with the smallest distance keeps u; remove it from the list
        i[notclosest] = -1  # and mark all the other repeats as -1
    return i

Note that I've used -1 instead of np.nan, since an index array has an integer dtype and cannot hold NaN. Any way to reduce the boolean indexing would help. I wanted to use one of the optional extra outputs of np.unique (such as return_index or return_inverse) but couldn't make it work.
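One way to drop the per-column boolean masks, sketched under assumptions (the name `find_smallest_sorted` and its `limit` parameter are hypothetical): sort the rows by (column index, distance) with `np.lexsort`, then take the first occurrence of each column index from the `return_index` output of `np.unique`. That first occurrence is the closest row for that column. Like the loop version, it resolves clashes in a single pass, so on the example array both return [-1, 0, -1, 2, -1, -1]: a bumped row becomes -1 instead of trying its next-closest column (row 2 does not fall back to column 1).

```python
import numpy as np

def find_smallest_sorted(a, limit=5.0):
    """Single-pass clash resolution without a Python loop (hypothetical helper)."""
    i = np.argmin(a, axis=1)
    amin = a[np.arange(len(a)), i]
    i[amin >= limit] = -1  # no distance below the cutoff
    # primary key: column index; secondary key: distance (stable for ties)
    order = np.lexsort((amin, i))
    # position of the first, i.e. closest, row for each distinct column index
    _, first = np.unique(i[order], return_index=True)
    winners = order[first]
    result = np.full(len(a), -1)
    result[winners] = i[winners]  # the -1 group's "winner" simply stays -1
    return result

a = np.array([[2.0, 12.1, 99.2],
              [1.0, 1.1, 1.2],
              [1.04, 1.05, 1.5],
              [4.1, 4.2, 0.2],
              [10.0, 11.0, 12.0],
              [3.9, 4.9, 4.99]])
print(find_smallest_sorted(a).tolist())  # [-1, 0, -1, 2, -1, -1]
```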
