Efficient way to compute an array-matrix multiplication for a batch of arrays
I want to parallelize the following problem.我想并行化以下问题。 Given an array
w
with shape (dim1,)
and a matrix A
with shape (dim1, dim2)
, I want each row of A
to be multiplied for the corresponding element of w
.给定一个形状为
(dim1,)
的数组w
和一个形状为(dim1, dim2)
的矩阵A
,我希望A
的每一行都与w
的相应元素相乘。 That's quite trivial.这很微不足道。
However, I want to do that for a bunch of arrays w and finally sum the results. So, to avoid a for loop, I created the matrix W with shape (n_samples, dim1), and I used the np.einsum function in the following way:

x = np.einsum('ji, ik -> jik', W, A)
r = x.sum(axis=0)
where the shape of x is (n_samples, dim1, dim2) and the final sum has shape (dim1, dim2).
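For concreteness, this is the per-sample loop the einsum call is meant to replace (sizes here are made up purely for illustration):

```python
import numpy as np

n_samples, dim1, dim2 = 5, 3, 4          # illustrative sizes
rng = np.random.default_rng(0)
W = rng.random((n_samples, dim1))        # each row of W is one sample w
A = rng.random((dim1, dim2))

# Loop version: scale the rows of A by each sample w, then accumulate.
r_loop = np.zeros((dim1, dim2))
for w in W:
    r_loop += w[:, None] * A

# einsum version from the question.
x = np.einsum('ji, ik -> jik', W, A)
r = x.sum(axis=0)

print(np.allclose(r_loop, r))  # True
```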
I noticed that np.einsum is quite slow for a large matrix A. Is there any more efficient way of solving this problem? I also wanted to try np.tensordot, but maybe that's not the right tool here.
Thank you :-)
In [455]: W = np.arange(1,7).reshape(2,3); A = np.arange(1,13).reshape(3,4)
Your calculation:
In [463]: x = np.einsum('ji, ik -> jik', W, A)
...: r = x.sum(axis=0)
In [464]: r
Out[464]:
array([[  5,  10,  15,  20],
       [ 35,  42,  49,  56],
       [ 81,  90,  99, 108]])
As noted in a comment, einsum can perform the sum on j:
In [465]: np.einsum('ji, ik -> ik', W, A)
Out[465]:
array([[  5,  10,  15,  20],
       [ 35,  42,  49,  56],
       [ 81,  90,  99, 108]])
And since j only occurs in W, we can sum W first:
In [466]: np.sum(W,axis=0)[:,None]*A
Out[466]:
array([[  5,  10,  15,  20],
       [ 35,  42,  49,  56],
       [ 81,  90,  99, 108]])
This doesn't involve a sum of products, so it isn't a matrix multiplication.
Or doing the sum after multiplication:
In [475]: (W[:,:,None]*A).sum(axis=0)
Out[475]:
array([[  5,  10,  15,  20],
       [ 35,  42,  49,  56],
       [ 81,  90,  99, 108]])
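All three formulations compute the same (dim1, dim2) result; a self-contained check on random inputs:

```python
import numpy as np

rng = np.random.default_rng(42)
W = rng.random((8, 5))   # (n_samples, dim1)
A = rng.random((5, 6))   # (dim1, dim2)

r1 = np.einsum('ji, ik -> jik', W, A).sum(axis=0)  # original two-step
r2 = np.einsum('ji, ik -> ik', W, A)               # sum folded into einsum
r3 = np.sum(W, axis=0)[:, None] * A                # sum W first, then broadcast
r4 = (W[:, :, None] * A).sum(axis=0)               # broadcast first, then sum

print(np.allclose(r1, r2) and np.allclose(r1, r3) and np.allclose(r1, r4))  # True
```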