簡體   English   中英

在二維 numpy 數組的每行中找到 N 個最高值

[英]Find N highest values per row of a 2d numpy array

我有一個 2d numpy 數組a

a = np.array(range(0,25)).reshape(5,5)
---
[[ 0  1  2  3  4]
 [ 5  6  7  8  9]
 [10 11 12 13 14]
 [15 16 17 18 19]
 [20 21 22 23 24]]

我想找到每行的最大 N 值並將它們替換為 100。我以緩慢的方式執行此操作:

N = 2
idx = a.argsort()
for i in range(a.shape[0]):
    a[i,idx[i][::-1][0:N]] = 100
print(a)
---
[[  0   1   2 100 100]
 [  5   6   7 100 100]
 [ 10  11  12 100 100]
 [ 15  16  17 100 100]
 [ 20  21  22 100 100]]

其實我矩陣的形狀是6000*6000。 如何以更好的方式做到這一點? 喜歡申請嗎?

您可以在此處使用argpartition

N=2
ix = a.argpartition(-N)[:,-N:]
a[np.arange(a.shape[0])[:,None], ix] = 100

print(a)
array([[  0,   1,   2, 100, 100],
       [  5,   6,   7, 100, 100],
       [ 10,  11,  12, 100, 100],
       [ 15,  16,  17, 100, 100],
       [ 20,  21,  22, 100, 100]])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM