簡體   English   中英

從Numpy ndarray中提取給定行和列的最快方法是什么?

[英]What is the fastest way to extract given rows and columns from a Numpy ndarray?

我有一個大的(大約14,000 x 14,000)方陣,表示為Numpy ndarray 我希望提取大量的行和列 - 我事先知道的索引,盡管它實際上是所有行和列都不是全零 - 以獲得一個新的方陣(大約10,000 x 10,000)。

我發現這樣做的最快方法是:

> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 6.19 s per loop

但是,這比進行矩陣乘法所需的時間要慢得多:

> timeit np.multiply(A, A)
1 loops, best of 3: 982 ms per loop

這看起來很奇怪,因為行/列提取和矩陣乘法都需要分配一個新數組(矩陣乘法的結果比提取的結果更大),但矩陣乘法也需要執行額外的計算。

因此,問題是:是否有更有效的方法來執行提取,特別是至少與矩陣乘法一樣快?

如果我試圖重現你的問題,我看不到這么大的影響。 我注意到,根據您選擇的索引數量,索引甚至可以比乘法更快。

>>> import numpy as np
>>> np.__version__
Out[1]: '1.9.0'
>>> N = 14000
>>> A = np.random.random(size=[N, N])

>>> indices = np.sort(np.random.choice(np.arange(N), 0.9*N, replace=False))
>>> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 1.02 s per loop
>>> timeit A.take(indices, axis=0).take(indices, axis=1)
1 loops, best of 3: 1.37 s per loop
>>> timeit np.multiply(A,A)
1 loops, best of 3: 748 ms per loop

>>> indices = np.sort(np.random.choice(np.arange(N), 0.7*N, replace=False))
>>> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 633 ms per loop
>>> timeit A.take(indices, axis=0).take(indices, axis=1)
1 loops, best of 3: 946 ms per loop
>>> timeit np.multiply(A,A)
1 loops, best of 3: 728 ms per loop

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM