繁体   English   中英

如何在numpy中矩阵乘以3个矩阵?

[英]How to matrix multiply 3 matrices in numpy?

给定一个 3x3 的 F 矩阵和一些 N x 3 的点矩阵,我想有效地计算 E = bT @ F @ a。

F = np.arange(9).reshape(3, 3)
>>> [[0 1 2]
     [3 4 5]
     [6 7 8]]

a = np.array([[1, 2, 1], [3, 4, 1], [5, 6, 1], [7, 8, 1]])
>>>[[1 2 1]
    [3 4 1]
    [5 6 1]
    [7 8 1]]

b = np.array([[10, 20, 1],[30, 40, 1],[50, 60, 1],[70, 80, 1]])
>>>[[10 20  1]
    [30 40  1]
    [50 60  1]
    [70 80  1]]

预期输出为: E = [388 1434 3120 5446]

我可以用一个简单的 for 循环来获得它,但我想用所有 numpy. 我尝试重塑 a 和 b 矩阵,但这并没有完全奏效。

N = b.shape[0] #4
a_reshaped = a.reshape(N, 3, 1)
b_reshaped = b.reshape(1, N, 3)

F_repeated = np.repeat(F[None,:], N, axis=0)
E = b_reshaped @ F_repeated @ a_reshaped 
>>>
[[[ 388]
  [ 788]
  [1188]
  [1588]]

 [[ 714]
  [1434]
  [2154]
  [2874]]

 [[1040]
  [2080]
  [3120]
  [4160]]

 [[1366]
  [2726]
  [4086]
  [5446]]]

如果我然后取对角线值,我会得到预期的结果,这是非常低效的。

有什么建议?

编辑:这是我所描述的 for 循环版本:

E = []
for k in range(N):
    error = b[k].T @ F @ a[k]
    E.append(error)
E = np.array(E)
>>>[388 1434 3120 5446]

根据@haveaball 的评论,您无法从一系列二维矩阵乘法中获得一维数组。

看起来你想要的是:

(b@F@a.T).diagonal()

(或(a@(b@F).T).diagonal()

输出: array([ 388, 1434, 3120, 5446])

我找到了一个避免制作 NxN 矩阵的解决方案:

partial_mult_b = b @ F
E_prime = partial_mult_b * a
E = np.sum(E_prime, axis=1)

首先,“天真”迭代版本 - 清除操作:

In [67]: c = np.zeros(4,int)
In [68]: for i in range(4):
    ...:     c[i] = b[i]@F@a[i]
    ...: 
In [69]: c
Out[69]: array([ 388, 1434, 3120, 5446])

简单的 einsum 版本

In [70]: np.einsum('ij,jk,ik->i',b,F,a)
Out[70]: array([ 388, 1434, 3120, 5446])

批处理 matmul 版本

In [71]: b[:,None,:]@F@a[:,:,None]
Out[71]: 
array([[[ 388]],

       [[1434]],

       [[3120]],

       [[5446]]])
In [72]: (b[:,None,:]@F@a[:,:,None]).squeeze()
Out[72]: array([ 388, 1434, 3120, 5446])

这个:

In [73]: ((b@F)*a).sum(axis=1)
Out[73]: array([ 388, 1434, 3120, 5446])

与将einsum扩展为两个步骤相同:

In [74]: np.einsum('ik,ik->i',np.einsum('ij,jk->ik',b,F),a)
Out[74]: array([ 388, 1434, 3120, 5446])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM