[英]What is the fastest way to extract given rows and columns from a Numpy ndarray?
我有一個大的(大約14,000 x 14,000)方陣,表示為Numpy ndarray
。 我希望提取大量的行和列 - 我事先知道的索引,盡管它實際上是所有行和列都不是全零 - 以獲得一個新的方陣(大約10,000 x 10,000)。
我發現這樣做的最快方法是:
> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 6.19 s per loop
但是,這比進行矩陣乘法所需的時間要慢得多:
> timeit np.multiply(A, A)
1 loops, best of 3: 982 ms per loop
這看起來很奇怪,因為行/列提取和矩陣乘法都需要分配一個新數組(矩陣乘法的結果比提取的結果更大),但矩陣乘法也需要執行額外的計算。
因此,問題是:是否有更有效的方法來執行提取,特別是至少與矩陣乘法一樣快?
如果我試圖重現你的問題,我看不到這么大的影響。 我注意到,根據您選擇的索引數量,索引甚至可以比乘法更快。
>>> import numpy as np
>>> np.__version__
Out[1]: '1.9.0'
>>> N = 14000
>>> A = np.random.random(size=[N, N])
>>> indices = np.sort(np.random.choice(np.arange(N), 0.9*N, replace=False))
>>> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 1.02 s per loop
>>> timeit A.take(indices, axis=0).take(indices, axis=1)
1 loops, best of 3: 1.37 s per loop
>>> timeit np.multiply(A,A)
1 loops, best of 3: 748 ms per loop
>>> indices = np.sort(np.random.choice(np.arange(N), 0.7*N, replace=False))
>>> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 633 ms per loop
>>> timeit A.take(indices, axis=0).take(indices, axis=1)
1 loops, best of 3: 946 ms per loop
>>> timeit np.multiply(A,A)
1 loops, best of 3: 728 ms per loop
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.