How would you write a less computationally intensive equivalent of numpy.where(np.ones(shape))?
I want to get a list of the indices of all elements of an array of a given shape. I found one easy way to do that:
import numpy as np
shape = (3,3)
elements = np.where(np.ones(shape))
The result is:
>>> elements
(array([0, 0, 0, 1, 1, 1, 2, 2, 2]), array([0, 1, 2, 0, 1, 2, 0, 1, 2]))
This is the expected behaviour. However, it doesn't seem to be the most compute-efficient way: if shape is huge, np.where can be quite sluggish, because it first has to build a full array of ones and then scan it for nonzero entries. I am looking for a more compute-efficient solution. Any ideas?
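Two alternatives worth trying build the index arrays directly instead of creating and scanning an array of ones. This is a minimal sketch: `indices_via_indices` and `indices_via_unravel` are hypothetical helper names, not NumPy functions, and neither approach was benchmarked in this post, so whether either beats np.where for a given shape is an assumption to measure. Both return the same tuple layout as np.where and also work for N-D shapes.

```python
import numpy as np

def indices_via_indices(shape):
    """Hypothetical helper: every element's index, one array per axis, in C order.

    np.indices builds the (ndim, *shape) index grid directly, so no array of
    ones has to be allocated and then scanned for nonzero entries.
    """
    grid = np.indices(shape)                     # shape (ndim, *shape)
    return tuple(grid.reshape(len(shape), -1))   # one flat array per axis

def indices_via_unravel(shape):
    """Hypothetical helper: convert flat positions 0..size-1 into N-D indices."""
    return np.unravel_index(np.arange(np.prod(shape)), shape)

print(indices_via_indices((3, 3)))
print(indices_via_unravel((3, 3)))
```

np.unravel_index does a division/modulo per axis, while np.indices fills the grid axis by axis; which one is cheaper likely depends on the shape and NumPy version.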
Based on the comments I have received, I implemented 3 ways to get the same result and tested their performance.
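import timeit
import numpy as np

def with_where(a):
    # np.where on an all-ones array returns the index of every element
    shape = a.shape
    return np.where(np.ones(shape))

def with_mgrid(a):
    # Build the dense 2-D index grid and flatten it to shape (2, rows * cols).
    # Only works for 2-D arrays; returns a single array, not a tuple.
    shape = a.shape
    grid_shape = (len(shape), np.prod(shape))
    return np.mgrid[0:shape[0], 0:shape[1]].reshape(grid_shape)

def with_repeat(a):
    # Row indices: each row number repeated once per column.
    # Column indices: the column range tiled once per row. 2-D arrays only.
    shape = a.shape
    return np.repeat(np.arange(shape[0]), shape[1]), np.tile(np.arange(shape[1]), shape[0])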
a1 = np.ones((1,1))
a10 = np.ones((10,10))
a100 = np.ones((100,100))
a1000 = np.ones((1000,1000))
a10000 = np.ones((10000,10000))
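with_mgrid and with_repeat above only handle 2-D arrays. The repeat/tile idea generalizes to any number of dimensions: along axis k, each index is repeated once per element of the trailing axes, and that pattern is tiled once per element of the leading axes. `with_repeat_nd` is a hypothetical name for this sketch; it is not part of the original benchmark and its speed was not measured.

```python
import math

import numpy as np

def with_repeat_nd(a):
    """Hypothetical N-D generalization of with_repeat (not from the original post).

    Returns one index array per axis, in the same C order as np.where(np.ones(a.shape)).
    """
    shape = a.shape
    out = []
    for k, size in enumerate(shape):
        inner = math.prod(shape[k + 1:])  # consecutive repeats of each index on axis k
        outer = math.prod(shape[:k])      # how many times the whole pattern is tiled
        out.append(np.tile(np.repeat(np.arange(size), inner), outer))
    return tuple(out)

print(with_repeat_nd(np.ones((2, 3))))
```

For a 2-D array this reduces to exactly the two calls in with_repeat (the tile count for axis 0 and the repeat count for axis 1 are both 1). math.prod needs Python 3.8 or later.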
Then I ran %timeit in IPython:
%timeit with_where(a1)
%timeit with_where(a10)
%timeit with_where(a100)
%timeit with_where(a1000)
%timeit with_where(a10000)
11.1 µs ± 163 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.7 µs ± 39.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
146 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
16.2 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.49 s ± 58.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit with_mgrid(a1)
%timeit with_mgrid(a10)
%timeit with_mgrid(a100)
%timeit with_mgrid(a1000)
%timeit with_mgrid(a10000)
50.2 µs ± 2.32 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
45.9 µs ± 989 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
75.1 µs ± 1.71 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
6.17 ms ± 54.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.1 s ± 40.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit with_repeat(a1)
%timeit with_repeat(a10)
%timeit with_repeat(a100)
%timeit with_repeat(a1000)
%timeit with_repeat(a10000)
23.3 µs ± 931 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
31 µs ± 739 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
66 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
4.41 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.05 s ± 22.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
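To reproduce measurements like these outside IPython, the standard-library timeit module can be used directly. This is a sketch under assumptions: `time_per_call` is a hypothetical helper, and unlike %timeit, which picks the loop count automatically, it uses a fixed `number` of calls per run and reports the mean and standard deviation over `repeat` runs.

```python
import statistics
import timeit

import numpy as np

def with_where(a):
    return np.where(np.ones(a.shape))

def with_repeat(a):
    rows, cols = a.shape
    return np.repeat(np.arange(rows), cols), np.tile(np.arange(cols), rows)

def time_per_call(func, a, number=10, repeat=7):
    """Hypothetical helper: mean and std. dev. of the per-call time in seconds,
    over `repeat` runs of `number` calls each (similar in spirit to %timeit)."""
    timer = timeit.Timer(lambda: func(a))
    runs = [t / number for t in timer.repeat(repeat=repeat, number=number)]
    return statistics.mean(runs), statistics.stdev(runs)

if __name__ == "__main__":
    for n in (10, 100, 1000):
        a = np.ones((n, n))
        for func in (with_where, with_repeat):
            mean, sd = time_per_call(func, a)
            print(f"{func.__name__:12s} n={n:5d}: {mean * 1e3:.3f} ms ± {sd * 1e3:.3f} ms")
```

Absolute numbers will differ by machine and NumPy version; only the relative ordering is worth comparing.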
So for large arrays, the np.where method is roughly 2.4 times slower than the fastest method, with_repeat (2.49 s vs. 1.05 s on the 10000×10000 array). This is not as bad as I thought.
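For the 10000×10000 case, every dense method has to write two index arrays of 10^8 entries each (about 800 MB per array with 8-byte integers), so much of the time plausibly goes into memory allocation and writes rather than arithmetic. If the indices are only needed for broadcasting arithmetic (for example `cols >= rows` for an upper-triangle mask), `np.indices(shape, sparse=True)` (NumPy 1.17 or later) returns one broadcastable array per axis and stores only sum(shape) indices. This sketch assumes that use case; if an explicit flat list of coordinates is really required, the dense arrays must be built anyway.

```python
import math

import numpy as np

shape = (10000, 10000)

# Dense layout: one index array of prod(shape) entries per axis,
# which is what np.where / with_mgrid / with_repeat all produce.
dense_bytes = len(shape) * math.prod(shape) * np.dtype(np.intp).itemsize

# Sparse layout: arrays of shape (10000, 1) and (1, 10000) that broadcast
# against each other to the full grid without materializing it.
rows, cols = np.indices(shape, sparse=True)
sparse_bytes = rows.nbytes + cols.nbytes

print(rows.shape, cols.shape)
print(dense_bytes, sparse_bytes)
```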
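Disclaimer: technical posts on this site are licensed under CC BY-SA 4.0; if you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.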