[英]Get index for argmin of 2d numpy array
我有一个二维的numpy距离数组:
a = np.array([[2.0, 12.1, 99.2],
[1.0, 1.1, 1.2],
[1.04, 1.05, 1.5],
[4.1, 4.2, 0.2],
[10.0, 11.0, 12.0],
[3.9, 4.9, 4.99]
])
我需要一个函数来评估每一行并返回具有最小值的列的列索引。 当然,可以通过执行以下操作来轻松完成此操作:
np.argmin(a, axis=1)
产生:
[0, 0, 0, 2, 0, 0]
但是,我有一些限制:
因此,最终输出应如下所示:
[-1, 0, 1, 2, -1, -1]
一个起点是:
有没有简单的方法可以在Python中完成此操作?
这会循环遍历列数,我认为该数小于行数:
def find_smallest(a):
i = np.argmin(a, 1)
amin = a[np.arange(len(a)), i] # faster than a.min(1)?
toobig = amin >=5
i[toobig] = -1
for u, c in zip(*np.unique(i, return_counts=True)):
#u, c are the unique values and number of occurrences in `i`
if c < 2:
# no repeats of this index
continue
mask = i==u # the values in i that match u, which has repeats
notclosest = np.where(mask)[0].tolist() # indices of the repeats
notclosest.pop(np.argmin(amin[mask])) # the smallest a value is not a 'repeat', remove it from the list
i[notclosest] = -1 # and mark all the repeats as -1
return i
注意,由于索引数组为int
,因此我使用-1
而不是np.nan
。 布尔索引的任何减少都会有所帮助。 我想使用np.unique(i)
的可选附加输出之一,但不能使用。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.