[英]Find N highest values per row of a 2d numpy array
我有一個 2d numpy 數組a
:
a = np.array(range(0,25)).reshape(5,5)
---
[[ 0 1 2 3 4]
[ 5 6 7 8 9]
[10 11 12 13 14]
[15 16 17 18 19]
[20 21 22 23 24]]
我想找到每行的最大 N 值並將它們替換為 100。我以緩慢的方式執行此操作:
N = 2
idx = a.argsort()
for i in range(a.shape[0]):
a[i,idx[i][::-1][0:N]] = 100
print(a)
---
[[ 0 1 2 100 100]
[ 5 6 7 100 100]
[ 10 11 12 100 100]
[ 15 16 17 100 100]
[ 20 21 22 100 100]]
其實我矩陣的形狀是6000*6000。 如何以更好的方式做到這一點? 喜歡申請嗎?
您可以在此處使用argpartition
:
N=2
ix = a.argpartition(-N)[:,-N:]
a[np.arange(a.shape[0])[:,None], ix] = 100
print(a)
array([[ 0, 1, 2, 100, 100],
[ 5, 6, 7, 100, 100],
[ 10, 11, 12, 100, 100],
[ 15, 16, 17, 100, 100],
[ 20, 21, 22, 100, 100]])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.