繁体   English   中英

获取二维numpy数组的argmin的索引

[英]Get index for argmin of 2d numpy array

我有一个二维的numpy距离数组:

a = np.array([[2.0, 12.1, 99.2], 
              [1.0, 1.1, 1.2], 
              [1.04, 1.05, 1.5], 
              [4.1, 4.2, 0.2], 
              [10.0, 11.0, 12.0], 
              [3.9, 4.9, 4.99] 
             ])

我需要一个函数来评估每一行并返回具有最小值的列的列索引。 当然,可以通过执行以下操作来轻松完成此操作:

np.argmin(a, axis=1) 

产生:

[0, 0, 0, 2, 0, 0]

但是,我有一些限制:

  1. argmin评估应仅考虑小于5.0的距离。 如果一行中的距离均不小于5.0,则返回“ -1”作为索引
  2. 为所有行返回的索引列表必须是唯一的(即,如果两个或更多行以相同的列索引结尾,那么与给定列索引的距离较小的行将被赋予优先级,而所有其他行必须返回不同的列指数)。 我猜想这将使该问题成为一个迭代问题,因为如果其中一行发生碰撞,则随后可能会与具有相同列索引的另一行发生冲突。
  3. 任何未分配的行应返回“ -1”

因此,最终输出应如下所示:

[-1, 0, 1, 2, -1, -1]

一个起点是:

  1. 执行一个argsort
  2. 为行分配唯一的列索引
  3. 从每行中删除分配的列索引
  4. 解决抢七局
  5. 重复步骤2-4,直到分配了所有列索引

有没有简单的方法可以在Python中完成此操作?

这会循环遍历列数,我认为该数小于行数:

def find_smallest(a):
    i = np.argmin(a, 1)
    amin = a[np.arange(len(a)), i] # faster than a.min(1)?
    toobig = amin >=5
    i[toobig] = -1
    for u, c in zip(*np.unique(i, return_counts=True)):
        #u, c are the unique values and number of occurrences in `i`
        if c < 2:
            # no repeats of this index
            continue
        mask = i==u # the values in i that match u, which has repeats
        notclosest = np.where(mask)[0].tolist() # indices of the repeats
        notclosest.pop(np.argmin(amin[mask])) # the smallest a value is not a 'repeat', remove it from the list
        i[notclosest] = -1 # and mark all the repeats as -1
    return i

注意,由于索引数组为int ,因此我使用-1而不是np.nan 布尔索引的任何减少都会有所帮助。 我想使用np.unique(i)的可选附加输出之一,但不能使用。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM