How would you write a less computationally intensive equivalent of numpy.where(np.ones(shape))?

I want to get the indices of every element of an array of a given shape. I found one easy way to do that:

import numpy as np
shape = (3,3)
elements = np.where(np.ones(shape))

The result is:

>>> elements
(array([0, 0, 0, 1, 1, 1, 2, 2, 2]), array([0, 1, 2, 0, 1, 2, 0, 1, 2]))

This is the expected behaviour. However, it doesn't seem to be the most compute-efficient way: if shape is huge, np.where can be quite sluggish. I am looking for a more compute-efficient solution. Any ideas?

Based on the comments I received, I implemented three ways to get the same result and tested their performance.

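import timeit
import numpy as np

def with_where(a):
    # Baseline: allocate an all-ones array, then ask for its nonzero indices.
    shape = a.shape
    return np.where(np.ones(shape))

def with_mgrid(a):
    # Build the 2-D index grids directly and flatten each one.
    # Returns a single (2, N) array rather than a tuple of two arrays.
    shape = a.shape
    grid_shape = (len(shape), np.prod(shape))
    return np.mgrid[0:shape[0], 0:shape[1]].reshape(grid_shape)

def with_repeat(a):
    # Row indices repeat each value shape[1] times; column indices cycle.
    shape = a.shape
    return np.repeat(np.arange(shape[0]), shape[1]), np.tile(np.arange(shape[1]), shape[0])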

a1 = np.ones((1,1))
a10 = np.ones((10,10))
a100 = np.ones((100,100))
a1000 = np.ones((1000,1000))
a10000 = np.ones((10000,10000))
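Before comparing timings, it may be worth confirming that the three approaches really produce the same indices. The following is a minimal sketch, not part of the original benchmark. The small non-square shape and the variable names are chosen here purely for illustration.

```python
import numpy as np

shape = (4, 3)  # illustrative shape; non-square so a row/column mix-up would show

# Reference result: a tuple of (row indices, column indices).
rows_ref, cols_ref = np.where(np.ones(shape))

# mgrid variant: one (2, N) array whose two rows are the index arrays.
grid = np.mgrid[0:shape[0], 0:shape[1]].reshape(len(shape), -1)

# repeat/tile variant: each row index is repeated, column indices cycle.
rows_rt = np.repeat(np.arange(shape[0]), shape[1])
cols_rt = np.tile(np.arange(shape[1]), shape[0])

same_mgrid = np.array_equal(grid[0], rows_ref) and np.array_equal(grid[1], cols_ref)
same_repeat = np.array_equal(rows_rt, rows_ref) and np.array_equal(cols_rt, cols_ref)
print(same_mgrid, same_repeat)
```

Note that the mgrid variant returns a single 2-D array, so it matches np.where row by row rather than as a tuple.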

Then I ran %timeit in IPython:

%timeit with_where(a1)
%timeit with_where(a10)
%timeit with_where(a100)
%timeit with_where(a1000)
%timeit with_where(a10000)

11.1 µs ± 163 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
12.7 µs ± 39.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
146 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
16.2 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.49 s ± 58.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit with_mgrid(a1)
%timeit with_mgrid(a10)
%timeit with_mgrid(a100)
%timeit with_mgrid(a1000)
%timeit with_mgrid(a10000)

50.2 µs ± 2.32 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
45.9 µs ± 989 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
75.1 µs ± 1.71 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
6.17 ms ± 54.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
2.1 s ± 40.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit with_repeat(a1)
%timeit with_repeat(a10)
%timeit with_repeat(a100)
%timeit with_repeat(a1000)
%timeit with_repeat(a10000)

23.3 µs ± 931 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
31 µs ± 739 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
66 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
4.41 ms ± 127 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
1.05 s ± 22.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So for large arrays, the np.where method is roughly 2.4x slower than the fastest method (2.49 s vs. 1.05 s for with_repeat at 10000 x 10000). For the smallest arrays, np.where is actually the fastest of the three. This is not as bad as I thought.
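The mgrid and repeat/tile versions above are written for 2-D shapes only. One possible way to cover an arbitrary number of dimensions is np.indices, sketched below. The helper name all_indices is hypothetical, and this variant was not part of the timings above, so its relative speed is untested here.

```python
import numpy as np

def all_indices(shape):
    """Return a tuple of flat index arrays, one per axis, in C order."""
    # np.indices builds an array of shape (ndim, *shape); flattening each
    # per-axis grid gives the same ordering as np.where on an all-true array.
    grids = np.indices(shape)
    return tuple(g.ravel() for g in grids)

# Compare against the np.where baseline on a small 3-D shape.
idx = all_indices((2, 3, 4))
ref = np.where(np.ones((2, 3, 4)))
matches = all(np.array_equal(a, b) for a, b in zip(idx, ref))
print(matches)
```

Returning a tuple keeps the result usable for fancy indexing in the same way as the output of np.where.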
