Custom non-linear matrix multiplication in NumPy
Suppose I have the matrices U and W:
import numpy as np
U = np.arange(6*2).reshape((6,2))
W = np.arange(5*2).reshape((5,2))
For a standard linear multiplication, I can do:
U @ W.T
array([[ 1, 3, 5, 7, 9],
[ 3, 13, 23, 33, 43],
[ 5, 23, 41, 59, 77],
[ 7, 33, 59, 85, 111],
[ 9, 43, 77, 111, 145],
[ 11, 53, 95, 137, 179]])
But I can also (technically) define a linear multiplication function that does this column by column and sums the results in a for loop:
def mult(U, W, i):
    return U[:, [i]] @ W.T[[i],:]
sum([mult(U, W, i) for i in range(2)]) #1
array([[ 1, 3, 5, 7, 9],
[ 3, 13, 23, 33, 43],
[ 5, 23, 41, 59, 77],
[ 7, 33, 59, 85, 111],
[ 9, 43, 77, 111, 145],
[ 11, 53, 95, 137, 179]])
Now suppose mult() is no longer linear but is some custom non-linear function, for example:
def mult(U, W, i):
    return (U[:, [i]] @ W.T[[i],:]) * np.cos(U[:, [i]] @ W.T[[i],:])
sum([mult(U, W, i) for i in range(2)]) #2
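You can verify that this is not the same as (U @ W.T) * np.cos(U @ W.T). What I would like to know is whether there is a more compact way to write #2, just as there is a more compact way to write #1 when mult() is linear. Efficiency would be nice, but I am not dealing with huge matrices.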
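The claim that #2 differs from applying the same function to the full product can be checked directly. The following is a minimal sketch using the arrays from the question; the names per_column and on_full_product are introduced here only for illustration.

```python
import numpy as np

U = np.arange(6 * 2).reshape((6, 2))
W = np.arange(5 * 2).reshape((5, 2))

def mult(U, W, i):
    # Outer product of column i of U with column i of W, then x * cos(x).
    p = U[:, [i]] @ W.T[[i], :]
    return p * np.cos(p)

# Version #2: apply the non-linear function per column, then sum.
per_column = sum(mult(U, W, i) for i in range(2))

# Apply the same function once to the already-summed matrix product.
on_full_product = (U @ W.T) * np.cos(U @ W.T)

print(np.allclose(per_column, on_full_product))  # False
```

The first row happens to agree because U[0, 0] is 0, so only one term contributes there; the remaining rows differ.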
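@, like np.dot, is a matrix multiplication, which involves what we usually call a sum of products. It is a basic linear algebra operation, and np.matmul uses efficient compiled libraries to perform it (where possible).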
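Your sum([mult(...)]) is doing exactly that: taking the row/column products and adding them up. The compiled code probably uses more efficient methods that work well when iterating in C or Fortran.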
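As a quick illustration of the point above, the following sketch checks that @, np.matmul, np.dot, and the column-by-column sum from the question all produce the same matrix for these 2-D inputs. The variable names are chosen here for illustration only.

```python
import numpy as np

U = np.arange(6 * 2).reshape((6, 2))
W = np.arange(5 * 2).reshape((5, 2))

# Three spellings of the same sum of products for 2-D arrays.
via_operator = U @ W.T
via_matmul = np.matmul(U, W.T)
via_dot = np.dot(U, W.T)

# The question's version: one (6, 5) outer product per column, then summed.
via_loop = sum(U[:, [i]] @ W.T[[i], :] for i in range(U.shape[1]))

print(np.array_equal(via_operator, via_matmul))  # True
print(np.array_equal(via_operator, via_dot))     # True
print(np.array_equal(via_operator, via_loop))    # True
```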
Your mult function can use broadcasted elementwise multiplication instead. For a single i:
In [43]: i=1;U[:, [i]] @ W.T[[i],:] # (6,1) @ (1,5) => (6,5)
Out[43]:
array([[ 1, 3, 5, 7, 9],
[ 3, 9, 15, 21, 27],
[ 5, 15, 25, 35, 45],
[ 7, 21, 35, 49, 63],
[ 9, 27, 45, 63, 81],
[11, 33, 55, 77, 99]])
In [44]: i=1;U[:, [i]] * W.T[[i],:]
Out[44]:
array([[ 1, 3, 5, 7, 9],
[ 3, 9, 15, 21, 27],
[ 5, 15, 25, 35, 45],
[ 7, 21, 35, 49, 63],
[ 9, 27, 45, 63, 81],
[11, 33, 55, 77, 99]])
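The equivalence shown in the two sessions above can also be expressed with np.outer. This is a small sketch under the same setup; as_matmul, as_broadcast, and as_outer are illustrative names, not part of the original answer.

```python
import numpy as np

U = np.arange(6 * 2).reshape((6, 2))
W = np.arange(5 * 2).reshape((5, 2))
i = 1

as_matmul = U[:, [i]] @ W.T[[i], :]      # (6, 1) @ (1, 5) -> (6, 5)
as_broadcast = U[:, [i]] * W.T[[i], :]   # (6, 1) * (1, 5) broadcasts to (6, 5)
as_outer = np.outer(U[:, i], W[:, i])    # outer product of two 1-D columns

print(as_matmul.shape, as_broadcast.shape, as_outer.shape)
print(np.array_equal(as_matmul, as_broadcast))  # True
print(np.array_equal(as_matmul, as_outer))      # True
```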
Without the list comprehension, this can be written as:
In [46]: (U[:,None,:]*W[None,:,:]).shape
Out[46]: (6, 5, 2)
In [47]: (U[:,None,:]*W[None,:,:]).sum(axis=2)
Out[47]:
array([[ 1, 3, 5, 7, 9],
[ 3, 13, 23, 33, 43],
[ 5, 23, 41, 59, 77],
[ 7, 33, 59, 85, 111],
[ 9, 43, 77, 111, 145],
[ 11, 53, 95, 137, 179]])
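For the linear case, the same contraction can also be spelled with np.einsum, which names the summed index explicitly. This is offered as an alternative sketch, not as something the answer itself uses.

```python
import numpy as np

U = np.arange(6 * 2).reshape((6, 2))
W = np.arange(5 * 2).reshape((5, 2))

# Broadcast to shape (6, 5, 2), then sum over the shared last axis.
via_broadcast = (U[:, None, :] * W[None, :, :]).sum(axis=2)

# Same contraction with einsum: k is the index that gets summed away.
via_einsum = np.einsum('ik,jk->ij', U, W)

print(np.array_equal(via_broadcast, U @ W.T))  # True
print(np.array_equal(via_einsum, U @ W.T))     # True
```

Unlike the broadcast form, einsum does not expose the intermediate (6, 5, 2) array, so it only helps when no non-linear function has to be applied before the sum.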
As for your np.cos version:
In [48]: def mult(U, W, i):
    ...:     return (U[:, [i]] @ W.T[[i],:]) * np.cos(U[:, [i]] @ W.T[[i],:])
    ...: sum([mult(U, W, i) for i in range(2)]) #2
Out[48]:
array([[ 5.40302306e-01, -2.96997749e+00, 1.41831093e+00,
5.27731578e+00, -8.20017236e+00],
[-2.96997749e+00, -1.08147468e+01, -1.25593190e+01,
-1.37606696e+00, -2.32102995e+01],
[ 1.41831093e+00, -1.25593190e+01, 9.45751861e+00,
-2.14489310e+01, 5.03346370e+01],
[ 5.27731578e+00, -1.37606696e+00, -2.14489310e+01,
1.01223418e+01, 3.13845563e+01],
[-8.20017236e+00, -2.32102995e+01, 5.03346370e+01,
3.13845563e+01, 8.79904273e+01],
[ 4.86826779e-02, 7.72350858e+00, -2.54605509e+01,
-5.95298563e+01, -4.88871235e+00]])
I can use the same outer-product/sum form:
In [49]: (U[:,None,:]*W[None,:,:]*np.cos(U[:,None,:]*W[None,:,:])).sum(axis=2)
Out[49]:
array([[ 5.40302306e-01, -2.96997749e+00, 1.41831093e+00,
5.27731578e+00, -8.20017236e+00],
[-2.96997749e+00, -1.08147468e+01, -1.25593190e+01,
-1.37606696e+00, -2.32102995e+01],
[ 1.41831093e+00, -1.25593190e+01, 9.45751861e+00,
-2.14489310e+01, 5.03346370e+01],
[ 5.27731578e+00, -1.37606696e+00, -2.14489310e+01,
1.01223418e+01, 3.13845563e+01],
[-8.20017236e+00, -2.32102995e+01, 5.03346370e+01,
3.13845563e+01, 8.79904273e+01],
[ 4.86826779e-02, 7.72350858e+00, -2.54605509e+01,
-5.95298563e+01, -4.88871235e+00]])
Since the outer product is used twice, we can store it in a temporary variable:
In [51]: temp=U[:,None,:]*W[None,:,:]; (temp*np.cos(temp)).sum(axis=2)
Out[51]:
array([[ 5.40302306e-01, -2.96997749e+00, 1.41831093e+00,
5.27731578e+00, -8.20017236e+00],
[-2.96997749e+00, -1.08147468e+01, -1.25593190e+01,
-1.37606696e+00, -2.32102995e+01],
[ 1.41831093e+00, -1.25593190e+01, 9.45751861e+00,
-2.14489310e+01, 5.03346370e+01],
[ 5.27731578e+00, -1.37606696e+00, -2.14489310e+01,
1.01223418e+01, 3.13845563e+01],
[-8.20017236e+00, -2.32102995e+01, 5.03346370e+01,
3.13845563e+01, 8.79904273e+01],
[ 4.86826779e-02, 7.72350858e+00, -2.54605509e+01,
-5.95298563e+01, -4.88871235e+00]])
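The temporary-variable pattern can be wrapped in a small helper that accepts any elementwise function. This is a minimal sketch; the name nonlinear_matmul and its f parameter are hypothetical and not part of NumPy.

```python
import numpy as np

def nonlinear_matmul(U, W, f):
    """Sum f(U[a, k] * W[b, k]) over k; f must operate elementwise on arrays.

    Hypothetical helper for illustration, not a NumPy function.
    """
    outer = U[:, None, :] * W[None, :, :]   # shape (rows of U, rows of W, k)
    return f(outer).sum(axis=2)

U = np.arange(6 * 2).reshape((6, 2))
W = np.arange(5 * 2).reshape((5, 2))

compact = nonlinear_matmul(U, W, lambda x: x * np.cos(x))

# Reference: the loop version #2 from the question.
def mult(U, W, i):
    p = U[:, [i]] @ W.T[[i], :]
    return p * np.cos(p)

looped = sum(mult(U, W, i) for i in range(U.shape[1]))
print(np.allclose(compact, looped))  # True

# With the identity function the helper reduces to the ordinary product.
print(np.array_equal(nonlinear_matmul(U, W, lambda x: x), U @ W.T))  # True
```

The intermediate array holds one value per (row of U, row of W, column) triple, so this form trades memory for compactness; that should be acceptable for the small matrices the question describes.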
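The fact that you cannot simply interchange the multiplication and summation steps is a matter of basic algebra.
To get
a1*b1 + a2*b2
from
(a1+a2)*(b1+b2) => a1*b1 + a1*b2 + a2*b1 + a2*b2
the cross terms a1*b2 + a2*b1 must sum to zero, as they do when computing the squared magnitude of a complex number: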
In [53]: (1+4j)*(1-4j)
Out[53]: (17+0j) # (1+16)
A sum of products generally cannot be converted into a product of sums.
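The algebra above can be checked with concrete numbers. In this sketch the values in a and b are arbitrary choices made for illustration; the complex pair reproduces the (1+4j)*(1-4j) case.

```python
import numpy as np

a = np.array([1.0, 2.0])
b = np.array([3.0, 5.0])

sum_of_products = (a * b).sum()           # a1*b1 + a2*b2
product_of_sums = a.sum() * b.sum()       # (a1+a2)*(b1+b2)
cross_terms = a[0] * b[1] + a[1] * b[0]   # a1*b2 + a2*b1

print(sum_of_products, product_of_sums, cross_terms)   # 13.0 24.0 11.0
print(product_of_sums - cross_terms == sum_of_products)  # True

# Complex case: the cross terms cancel, so both forms agree.
z = np.array([1, 4j])
w = np.array([1, -4j])
print((z * w).sum(), z.sum() * w.sum())   # both equal 17+0j
```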