Get index for argmin of 2d numpy array

I have a 2D numpy array of distances:

import numpy as np

a = np.array([[2.0, 12.1, 99.2],
              [1.0, 1.1, 1.2],
              [1.04, 1.05, 1.5],
              [4.1, 4.2, 0.2],
              [10.0, 11.0, 12.0],
              [3.9, 4.9, 4.99]
             ])

I need a function that assesses each row and returns the index of the column that has the smallest value. Of course, this can be done trivially by doing:

np.argmin(a, axis=1) 

which yields:

[0, 0, 0, 2, 0, 0]

However, I have a few constraints:

  1. The argmin evaluation should only consider distances below a value of 5.0. If none of the distances within a row are below 5.0, then return '-1' as the index.
  2. The list of indices returned for all rows needs to be unique (i.e. if two or more rows end up with the same column index, then the row with the smaller distance to the given column index is given priority and all other rows must return a different column index). I'm guessing that this will make the problem an iterative one, since if one of the rows gets bumped then it could subsequently clash with another row with the same column index.
  3. Any unassigned rows should return '-1'.
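
Constraint 1 on its own can be handled without a loop. The following is only a minimal sketch, assuming the array `a` from above; the names `threshold`, `eligible`, `masked` and `idx` are illustrative. Distances at or above the threshold are replaced with infinity before taking the argmin, and rows with no eligible distance are set to -1. It does not address the uniqueness requirement in constraint 2.

```python
import numpy as np

a = np.array([[2.0, 12.1, 99.2],
              [1.0, 1.1, 1.2],
              [1.04, 1.05, 1.5],
              [4.1, 4.2, 0.2],
              [10.0, 11.0, 12.0],
              [3.9, 4.9, 4.99]])

threshold = 5.0                       # cutoff from constraint 1
eligible = a < threshold              # True where a distance may be considered
masked = np.where(eligible, a, np.inf)  # ineligible distances can never win the argmin
idx = masked.argmin(axis=1)
idx[~eligible.any(axis=1)] = -1       # rows with no distance below the cutoff

print(idx.tolist())  # [0, 0, 0, 2, -1, 0]
```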

Thus, the final output should look like:

[-1, 0, 1, 2, -1, -1]

One starting point would be to:

  1. perform an argsort
  2. assign unique column indices to rows
  3. remove assigned column indices from each row
  4. resolve tie-breaks
  5. repeat steps 2-4 until all column indices are assigned
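
The steps above could be sketched roughly as follows. This is one possible reading of the outline, not a definitive implementation: instead of a full argsort it recomputes the argmin each round, and the function name `assign_unique_argmin` and its `threshold` parameter are made up for illustration. In each round every unassigned row proposes its best remaining column, the row with the smallest distance wins each contested column, and the winning row and column are then removed from further consideration.

```python
import numpy as np

def assign_unique_argmin(a, threshold=5.0):
    """Round-based assignment of unique column indices; -1 means unassigned."""
    a = np.asarray(a, dtype=float)
    n_rows = a.shape[0]
    # Entries at or above the threshold are never eligible.
    work = np.where(a < threshold, a, np.inf)
    result = np.full(n_rows, -1, dtype=int)
    while True:
        best_col = work.argmin(axis=1)
        best_val = work[np.arange(n_rows), best_col]
        # Rows that still have at least one eligible, untaken column.
        proposers = np.flatnonzero(np.isfinite(best_val))
        if proposers.size == 0:
            break
        for col in np.unique(best_col[proposers]):
            rivals = proposers[best_col[proposers] == col]
            winner = rivals[np.argmin(best_val[rivals])]  # smallest distance wins
            result[winner] = col
            work[winner, :] = np.inf  # this row is settled
            work[:, col] = np.inf     # this column is taken
    return result

a = np.array([[2.0, 12.1, 99.2],
              [1.0, 1.1, 1.2],
              [1.04, 1.05, 1.5],
              [4.1, 4.2, 0.2],
              [10.0, 11.0, 12.0],
              [3.9, 4.9, 4.99]])
print(assign_unique_argmin(a).tolist())  # [-1, 0, 1, 2, -1, -1]
```

Each round that has any proposer assigns at least one column, so the loop runs at most as many times as there are columns.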

Is there a straightforward way to accomplish this in Python?

This loops over the number of columns, which I assume is smaller than the number of rows:

import numpy as np

def find_smallest(a):
    i = np.argmin(a, 1)  # column of the smallest value in each row
    amin = a[np.arange(len(a)), i]  # the row minima; faster than a.min(1)?
    toobig = amin >= 5  # rows with no distance below 5.0
    i[toobig] = -1
    for u, c in zip(*np.unique(i, return_counts=True)):
        # u, c are the unique values and number of occurrences in `i`
        if u == -1 or c < 2:
            # already unassigned, or no repeats of this index
            continue
        mask = i == u  # the values in i that match u, which has repeats
        notclosest = np.where(mask)[0].tolist()  # row indices of the repeats
        notclosest.pop(np.argmin(amin[mask]))  # the row with the smallest value keeps u, so remove it from the list
        i[notclosest] = -1  # mark the other rows as -1 (they are not moved to another column)
    return i

Note, I've used -1 instead of np.nan since an index array has an integer dtype. Any reduction in the boolean indexing would help. I wanted to use one of the optional additional outputs from np.unique(i) but couldn't.
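
Note that find_smallest marks every bumped row as -1 and does not try its next-best column, so for the example array it returns [-1, 0, -1, 2, -1, -1] rather than the [-1, 0, 1, 2, -1, -1] asked for in the question. If reassignment is needed, one option is a greedy pass over all eligible (row, column) pairs in order of increasing distance. The sketch below is only an illustration under that assumption; `greedy_unique_argmin` and `threshold` are hypothetical names, not part of NumPy.

```python
import numpy as np

def greedy_unique_argmin(a, threshold=5.0):
    """Greedy assignment over eligible pairs, smallest distance first."""
    a = np.asarray(a, dtype=float)
    result = np.full(a.shape[0], -1, dtype=int)   # -1 means unassigned
    col_taken = np.zeros(a.shape[1], dtype=bool)
    rows, cols = np.nonzero(a < threshold)        # all eligible (row, col) pairs
    order = np.argsort(a[rows, cols], kind="stable")
    for r, c in zip(rows[order], cols[order]):
        # Take the pair only if both the row and the column are still free.
        if result[r] == -1 and not col_taken[c]:
            result[r] = c
            col_taken[c] = True
    return result

a = np.array([[2.0, 12.1, 99.2],
              [1.0, 1.1, 1.2],
              [1.04, 1.05, 1.5],
              [4.1, 4.2, 0.2],
              [10.0, 11.0, 12.0],
              [3.9, 4.9, 4.99]])
print(greedy_unique_argmin(a).tolist())  # [-1, 0, 1, 2, -1, -1]
```

This sorts every eligible entry once, so it trades the per-column loop for a single sort over at most rows × columns values.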
