繁体   English   中英

从Numpy ndarray中提取给定行和列的最快方法是什么?

[英]What is the fastest way to extract given rows and columns from a Numpy ndarray?

我有一个大的(大约14,000 x 14,000)方阵,表示为Numpy ndarray 我希望提取大量的行和列 - 我事先知道的索引,尽管它实际上是所有行和列都不是全零 - 以获得一个新的方阵(大约10,000 x 10,000)。

我发现这样做的最快方法是:

> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 6.19 s per loop

但是,这比进行矩阵乘法所需的时间要慢得多:

> timeit np.multiply(A, A)
1 loops, best of 3: 982 ms per loop

这看起来很奇怪,因为行/列提取和矩阵乘法都需要分配一个新数组(矩阵乘法的结果比提取的结果更大),但矩阵乘法也需要执行额外的计算。

因此,问题是:是否有更有效的方法来执行提取,特别是至少与矩阵乘法一样快?

如果我试图重现你的问题,我看不到这么大的影响。 我注意到,根据您选择的索引数量,索引甚至可以比乘法更快。

>>> import numpy as np
>>> np.__version__
Out[1]: '1.9.0'
>>> N = 14000
>>> A = np.random.random(size=[N, N])

>>> indices = np.sort(np.random.choice(np.arange(N), 0.9*N, replace=False))
>>> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 1.02 s per loop
>>> timeit A.take(indices, axis=0).take(indices, axis=1)
1 loops, best of 3: 1.37 s per loop
>>> timeit np.multiply(A,A)
1 loops, best of 3: 748 ms per loop

>>> indices = np.sort(np.random.choice(np.arange(N), 0.7*N, replace=False))
>>> timeit A[np.ix_(indices, indices)]
1 loops, best of 3: 633 ms per loop
>>> timeit A.take(indices, axis=0).take(indices, axis=1)
1 loops, best of 3: 946 ms per loop
>>> timeit np.multiply(A,A)
1 loops, best of 3: 728 ms per loop

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM