[英]Combining element-wise and matrix multiplication with multi-dimensional arrays in NumPy
我有两个多维 NumPy 数组, A
和B
, A.shape = (K, d, N)
和B.shape = (K, N, d)
。 我想在轴 0 ( K
) 上执行元素操作,该操作是在轴 1 和 2 ( d, N
和N, d
) 上的矩阵乘法。 所以结果应该是一个多维数组C
其中C.shape = (K, d, d)
,所以C[k] = np.dot(A[k], B[k])
。 一个简单的实现看起来像这样:
C = np.vstack([np.dot(A[k], B[k])[np.newaxis, :, :] for k in xrange(K)])
但是这个实现很慢。 稍微快一点的方法如下所示:
C = np.dot(A, B)[:, :, 0, :]
它在多维数组上使用np.dot
的默认行为,给我一个形状为(K, d, K, d)
的数组。 但是,这种方法计算所需的答案K
次(沿轴 2 的每个条目都相同)。 渐近地它会比第一种方法慢,但开销要少得多。 我也知道以下方法:
from numpy.core.umath_tests import matrix_multiply
C = matrix_multiply(A, B)
但我不保证此功能将可用。 因此,我的问题是,NumPy 是否提供了有效执行此操作的标准方法? 一般适用于多维数组的答案将是完美的,但仅针对这种情况的答案也会很棒。
编辑:正如@Juh_ 所指出的,第二种方法是不正确的。 正确的版本是:
C = np.dot(A, B).diagonal(axis1=0, axis2=2).transpose(2, 0, 1)
但是增加的开销使它比第一种方法慢,即使对于小矩阵也是如此。 最后一种方法是在我所有的时序测试中,无论是小矩阵还是大矩阵,都遥遥领先。 如果没有更好的解决方案出现,我现在强烈考虑使用它,即使这意味着将numpy.core.umath_tests
库(用 C 编写)复制到我的项目中。
您的问题的可能解决方案是:
C = np.sum(A[:,:,:,np.newaxis]*B[:,np.newaxis,:,:],axis=2)
然而:
顺便说一句,请注意:
C = np.dot(A, B)[:, :, 0, :]
没有给出正确的结果。 它让我被骗了,因为我首先通过将结果与 np.dot 命令给出的结果进行比较来检查我的方法。
我的项目中有同样的问题。 我能想到的最好的方法是,我认为它比使用vstack
快一点(可能快 10%):
K, d, N = A.shape
C = np.empty((K, d, d))
for k in xrange(K):
C[k] = np.dot(A[k], B[k])
我很想看到更好的解决方案,我不太明白人们会如何使用tensordot
来做到这一点。
一个非常灵活、紧凑且快速的解决方案:
C = np.einsum('Kab,Kbc->Kac', A, B, optimize=True)
确认:
import numpy as np
K = 10
d = 5
N = 3
A = np.random.rand(K,d,N)
B = np.random.rand(K,N,d)
C_old = np.dot(A, B).diagonal(axis1=0, axis2=2).transpose(2, 0, 1)
C_new = np.einsum('Kab,Kbc->Kac', A, B)
print(np.max(C_old-C_new)) # should be 0 or a very small number
对于大型多维数组,可选参数optimize=True
可以为您节省大量时间。 您可以在此处了解einsum :
https://ajcr.net/Basic-guide-to-einsum/
https://rockt.github.io/2018/04/30/einsum
https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
引用:
爱因斯坦求和约定可用于计算许多多维线性代数数组运算。 einsum提供了一种简洁的方式来表示这些。 这些操作的非详尽列表是:
数组的跟踪, numpy.trace 。
返回对角线numpy.diag 。
数组轴求和, numpy.sum 。
换位和排列, numpy.transpose 。
矩阵乘法和点积, numpy.matmul numpy.dot 。
矢量内积和外积, numpy.inner numpy.outer 。
广播,元素和标量乘法, numpy.multiply 。
张量收缩, numpy.tensordot 。
链式数组操作,以高效的计算顺序, numpy.einsum_path 。
你可以做
np.matmul(A, B)
查看https://numpy.org/doc/stable/reference/generated/numpy.matmul.html 。
对于足够大的K
应该比 einsum 快。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.