[英]efficient loop over numpy array
已经问过这个问题的版本,但我没有找到满意的答案。
问题 :给定一个大的numpy向量,找到重复的向量元素的索引(其变化可以与容差进行比较)。
所以问题是~O(N ^ 2)和内存限制(至少从当前算法的角度来看)。 我想知道为什么我尝试Python的速度比同等的C代码慢100倍或更慢。
import numpy as np
N = 10000
vect = np.arange(float(N))
vect[N/2] = 1
vect[N/4] = 1
dupl = []
print("init done")
counter = 0
for i in range(N):
for j in range(i+1, N):
if vect[i] == vect[j]:
dupl.append(j)
counter += 1
print("counter =", counter)
print(dupl)
# For simplicity, this code ignores repeated indices
# which can be trimmed later. Ref output is
# counter = 3
# [2500, 5000, 5000]
我尝试使用numpy迭代器,但它们更糟糕(~x4-5) http://docs.scipy.org/doc/numpy/reference/arrays.nditer.html
使用N = 10,000我在C中获得0.1秒,在Python中获得12秒(上面的代码),在Python中使用np.nditer获得40秒,使用np.ndindex在Python中获得50秒。 我把它推到N = 160,000,时间按预期按N ^ 2计算。
由于答案已经停止,没有一个完全令人满意,为了记录,我发布了自己的解决方案。
我的理解是,在这种情况下,使得Python变慢的赋值,而不是我最初想到的嵌套循环。 使用库或编译代码消除了分配的需要,并且性能得到显着提高。
from __future__ import print_function
import numpy as np
from numba import jit
N = 10000
vect = np.arange(N, dtype=np.float32)
vect[N/2] = 1
vect[N/4] = 1
dupl = np.zeros(N, dtype=np.int32)
print("init done")
# uncomment to enable compiled function
#@jit
def duplicates(i, counter, dupl, vect):
eps = 0.01
ns = len(vect)
for j in range(i+1, ns):
# replace if to use approx comparison
#if abs(vect[i] - vect[j]) < eps:
if vect[i] == vect[j]:
dupl[counter] = j
counter += 1
return counter
counter = 0
for i in xrange(N):
counter = duplicates(i, counter, dupl, vect)
print("counter =", counter)
print(dupl[0:counter])
测试
# no jit
$ time python array-test-numba.py
init done
counter = 3
[2500 5000 5000]
elapsed 10.135 s
# with jit
$ time python array-test-numba.py
init done
counter = 3
[2500 5000 5000]
elapsed 0.480 s
编译版本的性能(@jit取消注释)接近C代码性能~0.1 - 0.2秒。 也许消除最后一个循环可以进一步提高性能。 当使用eps进行近似比较时,性能差异甚至更强,而编译版本的差异非常小。
# no jit
$ time python array-test-numba.py
init done
counter = 3
[2500 5000 5000]
elapsed 109.218 s
# with jit
$ time python array-test-numba.py
init done
counter = 3
[2500 5000 5000]
elapsed 0.506 s
这是〜200倍的差异。 在实际代码中,我必须将两个循环放在函数中以及使用具有变量类型的函数模板,因此它有点复杂但不是很多。
Python本身是一种高度动态,缓慢的语言。 numpy中的想法是使用矢量化 ,并避免显式循环。 在这种情况下,您可以使用np.equal.outer
。 你可以先开始
a = np.equal.outer(vect, vect)
现在,例如,找到总和:
>>> np.sum(a)
10006
要找到相等的i的索引,你可以这样做
np.fill_diagonal(a, 0)
>>> np.nonzero(np.any(a, axis=0))[0]
array([ 1, 2500, 5000])
定时
def find_vec():
a = np.equal.outer(vect, vect)
s = np.sum(a)
np.fill_diagonal(a, 0)
return np.sum(a), np.nonzero(np.any(a, axis=0))[0]
>>> %timeit find_vec()
1 loops, best of 3: 214 ms per loop
def find_loop():
dupl = []
counter = 0
for i in range(N):
for j in range(i+1, N):
if vect[i] == vect[j]:
dupl.append(j)
counter += 1
return dupl
>>> % timeit find_loop()
1 loops, best of 3: 8.51 s per loop
显而易见的问题是为什么你想以这种方式做到这一点。 NumPy数组旨在成为不透明的数据结构 - 我的意思是NumPy数组旨在在NumPy系统内创建,然后将操作发送到NumPy子系统以提供结果。 即NumPy应该是一个黑盒子,你向其投掷请求并输出结果。
所以考虑到上面的代码,我并不感到惊讶NumPy的表现比糟糕的更糟糕。
以下应该是你想要的,我相信,但做了NumPy方式:
import numpy as np
N = 10000
vect = np.arange(float(N))
vect[N/2] = 1
vect[N/4] = 1
print([np.where(a == vect)[0] for a in vect][1])
# Delivers [1, 2500, 5000]
方法#1
您可以使用triangular matrix
模拟矢量化解决方案的迭代器依赖性条件。 这是基于this post
涉及iterator dependency
乘法的this post
。 对于执行的每个元素的元素单元的平等vect
反对它的所有元素,我们可以使用NumPy broadcasting
。 最后,我们可以使用np.count_nonzero
来获取计数,因为它应该在布尔数组的求和方面非常有效。
那么,我们会有这样的解决方案 -
mask = np.triu(vect[:,None] == vect,1)
counter = np.count_nonzero(mask)
dupl = np.where(mask)[1]
如果您只关心counter
,我们可以再选择两种方法。
方法#2
我们可以避免使用三角矩阵并简单地得到整个计数并简单地减去对角元素的贡献,并且只考虑将剩余计数减半,只考虑上三角区域中较低的一个,因为任何一个的贡献都是相同的。
那么,我们会有一个像这样的修改过的解决方案 -
counter = (np.count_nonzero(vect[:,None] == vect) - vect.size)//2
方法#3
这是一种完全不同的方法,它使用每个独特元素的计数对最终总数的累计贡献这一事实。
所以,考虑到这个想法,我们会有第三种方法 -
count = np.bincount(vect) # OR np.unique(vect,return_counts=True)[1]
idx = count[count>1]
id_arr = np.ones(idx.sum(),dtype=int)
id_arr[0] = 0
id_arr[idx[:-1].cumsum()] = -idx[:-1]+1
counter = np.sum(id_arr.cumsum())
我想知道为什么我尝试Python的速度比同等的C代码慢100倍或更慢。
因为Python程序通常比C程序慢100倍。
您可以在C中实现关键代码路径并提供Python-C绑定,也可以更改算法。 您可以使用将数组从值反转为索引的dict
来编写O(N)版本。
import numpy as np
N = 10000
vect = np.arange(float(N))
vect[N/2] = 1
vect[N/4] = 1
dupl = {}
print("init done")
counter = 0
for i in range(N):
e = dupl.get(vect[i], None)
if e is None:
dupl[vect[i]] = [i]
else:
e.append(i)
counter += 1
print("counter =", counter)
print([(k, v) for k, v in dupl.items() if len(v) > 1])
编辑:
如果您需要针对具有abs(vect [i] - vect [j])<eps的eps进行测试,则可以将值标准化为eps
abs(vect[i] - vect[j]) < eps ->
abs(vect[i] - vect[j]) / eps < (eps / eps) ->
abs(vect[i]/eps - vect[j]/eps) < 1
int(abs(vect[i]/eps - vect[j]/eps)) = 0
像这样:
import numpy as np
N = 10000
vect = np.arange(float(N))
vect[N/2] = 1
vect[N/4] = 1
dupl = {}
print("init done")
counter = 0
eps = 0.01
for i in range(N):
k = int(vect[i] / eps)
e = dupl.get(k, None)
if e is None:
dupl[k] = [i]
else:
e.append(i)
counter += 1
print("counter =", counter)
print([(k, v) for k, v in dupl.items() if len(v) > 1])
这个使用numpy_indexed包的解决方案具有复杂性n Log n,并且是完全向量化的; 所以很可能与C表现没有太大的不同。
import numpy_indexed as npi
dpl = np.flatnonzero(npi.multiplicity(vect) > 1)
作为Ami Tavory答案的替代方案,您可以使用集合包中的计数器来检测重复项。 在我的电脑上,它似乎更快。 请参阅下面的函数,它也可以找到不同的重复项。
import collections
import numpy as np
def find_duplicates_original(x):
d = []
for i in range(len(x)):
for j in range(i + 1, len(x)):
if x[i] == x[j]:
d.append(j)
return d
def find_duplicates_outer(x):
a = np.equal.outer(x, x)
np.fill_diagonal(a, 0)
return np.flatnonzero(np.any(a, axis=0))
def find_duplicates_counter(x):
counter = collections.Counter(x)
values = (v for v, c in counter.items() if c > 1)
return {v: np.flatnonzero(x == v) for v in values}
n = 10000
x = np.arange(float(n))
x[n // 2] = 1
x[n // 4] = 1
>>>> find_duplicates_counter(x)
{1.0: array([ 1, 2500, 5000], dtype=int64)}
>>>> %timeit find_duplicates_original(x)
1 loop, best of 3: 12 s per loop
>>>> %timeit find_duplicates_outer(x)
10 loops, best of 3: 84.3 ms per loop
>>>> %timeit find_duplicates_counter(x)
1000 loops, best of 3: 1.63 ms per loop
这比8秒运行,而代码为18秒,不使用任何奇怪的库。 它类似于@ vs0的方法,但我更喜欢defaultdict
。 它应该约为O(N)。
from collections import defaultdict
dupl = []
counter = 0
indexes = defaultdict(list)
for i, e in enumerate(vect):
indexes[e].append(i)
if len(indexes[e]) > 1:
dupl.append(i)
counter += 1
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.