[英]Vectorize finding closest value in an array for each element in another array
[英]Vectorize addition into array indexed by another array
我正在尝试获取以下循环的快速矢量化版本:
for i in xrange(N1):
A[y[i]] -= B[i,:]
这里A.shape = (N2,N3)
, y.shape = (N1)
, y
取[0,N2[
, B.shape = (N1,N3)
。 您可以认为y
的条目是A
行的索引。 N1
很大, N2
很小, N3
很小。
我以为干嘛
A[y] -= B
会起作用,但是问题是y
中有重复的条目,这不能做正确的事(即,如果y=[1,1]
则A[1]
仅被添加一次,而不是两次)。 同样,这似乎没有比未向量化的for循环更快。
有更好的方法吗?
编辑:YXD 将此评论链接到最初似乎符合要求的评论中。 看来您可以完全按照我的意愿做
np.subtract.at(A, y, B)
并且确实可以,但是当我尝试运行它时,它的速度明显比未向量化的版本慢 。 因此,问题仍然存在:是否有更高效的方法?
EDIT2:一个例子,使事情具体:
n1,n2,n3 = 10000, 10, 500
A = np.random.rand(n2,n3)
y = np.random.randint(n2, size=n1)
B = np.random.rand(n1,n3)
在ipython中使用%timeit
运行时,for循环在我的机器上给出:
10 loops, best of 3: 19.4 ms per loop
最后, subtract.at
版本会为A
产生相同的值,但要慢得多:
1 loops, best of 3: 444 ms per loop
原始的基于for循环方法的代码看起来像这样-
def for_loop(A):
N1 = B.shape[0]
for i in xrange(N1):
A[y[i]] -= B[i,:]
return A
情况1
如果n2 >> n3,我建议采用这种向量化方法-
def bincount_vectorized(A):
n3 = A.shape[1]
nrows = y.max()+1
id = y[:,None] + nrows*np.arange(n3)
A[:nrows] -= np.bincount(id.ravel(),B.ravel()).reshape(n3,nrows).T
return A
运行时测试-
In [203]: n1,n2,n3 = 10000, 500, 10
...: A = np.random.rand(n2,n3)
...: y = np.random.randint(n2, size=n1)
...: B = np.random.rand(n1,n3)
...:
...: # Make copies
...: Acopy1 = A.copy()
...: Acopy2 = A.copy()
...:
In [204]: %timeit for_loop(Acopy1)
10 loops, best of 3: 19 ms per loop
In [205]: %timeit bincount_vectorized(Acopy2)
1000 loops, best of 3: 779 µs per loop
情况#2
如果n2 << n3,则可以提出一种改进的for循环方法,其循环复杂度较低-
def for_loop_v2(A):
n2 = A.shape[0]
for i in range(n2):
A[i] -= np.einsum('ij->j',B[y==i]) # OR (B[y==i]).sum(0)
return A
运行时测试-
In [206]: n1,n2,n3 = 10000, 10, 500
...: A = np.random.rand(n2,n3)
...: y = np.random.randint(n2, size=n1)
...: B = np.random.rand(n1,n3)
...:
...: # Make copies
...: Acopy1 = A.copy()
...: Acopy2 = A.copy()
...:
In [207]: %timeit for_loop(Acopy1)
10 loops, best of 3: 24.2 ms per loop
In [208]: %timeit for_loop_v2(Acopy2)
10 loops, best of 3: 20.3 ms per loop
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.