繁体   English   中英

numpy dot() 和 Python 3.5+ 矩阵乘法之间的区别@

[英]Difference between numpy dot() and Python 3.5+ matrix multiplication @

我最近迁移到 Python 3.5 并注意到新的矩阵乘法运算符 (@)有时与numpy 点运算符的行为不同。 例如,对于 3d 数组:

import numpy as np

a = np.random.rand(8,13,13)
b = np.random.rand(8,13,13)
c = a @ b  # Python 3.5+
d = np.dot(a, b)

@运算符返回一个形状数组:

c.shape
(8, 13, 13)

np.dot()函数返回:

d.shape
(8, 13, 8, 13)

如何使用 numpy dot 重现相同的结果? 还有其他显着差异吗?

@运算符调用数组的__matmul__方法,而不是dot 此方法也作为函数np.matmul存在于 API 中。

>>> a = np.random.rand(8,13,13)
>>> b = np.random.rand(8,13,13)
>>> np.matmul(a, b).shape
(8, 13, 13)

从文档中:

matmul在两个重要方面与dot不同。

  • 不允许乘以标量。
  • 矩阵堆栈一起广播,就好像矩阵是元素一样。

最后一点清楚地表明dotmatmul方法在传递 3D(或更高维)数组时表现不同。 从文档中引用更多:

对于matmul

如果任一参数为 ND,N > 2,则将其视为驻留在最后两个索引中的矩阵堆栈并相应地广播。

对于np.dot

对于二维数组,它相当于矩阵乘法,对于一维数组,它相当于向量的内积(没有复共轭)。 对于 N 维,它是 a 的最后一个轴和 b 的倒数第二个轴的和积

仅供参考, @及其 numpy 等效项dotmatmul都同样快。 (使用perfplot创建的绘图,我的一个项目。)

在此处输入图像描述

重现情节的代码:

import perfplot
import numpy


def setup(n):
    A = numpy.random.rand(n, n)
    x = numpy.random.rand(n)
    return A, x


def at(data):
    A, x = data
    return A @ x


def numpy_dot(data):
    A, x = data
    return numpy.dot(A, x)


def numpy_matmul(data):
    A, x = data
    return numpy.matmul(A, x)


perfplot.show(
    setup=setup,
    kernels=[at, numpy_dot, numpy_matmul],
    n_range=[2 ** k for k in range(15)],
)

@ajcr 的答案解释了dotmatmul (由@符号调用)的不同之处。 通过看一个简单的例子,可以清楚地看到两者在“矩阵堆栈”或张量上的行为有何不同。

为了澄清差异,采用 4x4 数组并返回dot积和matmul积以及 3x4x2“矩阵堆栈”或张量。

import numpy as np
fourbyfour = np.array([
                       [1,2,3,4],
                       [3,2,1,4],
                       [5,4,6,7],
                       [11,12,13,14]
                      ])


threebyfourbytwo = np.array([
                             [[2,3],[11,9],[32,21],[28,17]],
                             [[2,3],[1,9],[3,21],[28,7]],
                             [[2,3],[1,9],[3,21],[28,7]],
                            ])

print('4x4*3x4x2 dot:\n {}\n'.format(np.dot(fourbyfour,threebyfourbytwo)))
print('4x4*3x4x2 matmul:\n {}\n'.format(np.matmul(fourbyfour,threebyfourbytwo)))

每个操作的产品如下所示。 注意点积是怎样的,

...a 的最后一个轴和 b 的倒数第二个轴的和积

以及如何将矩阵广播在一起形成矩阵乘积。

4x4*3x4x2 dot:
 [[[232 152]
  [125 112]
  [125 112]]

 [[172 116]
  [123  76]
  [123  76]]

 [[442 296]
  [228 226]
  [228 226]]

 [[962 652]
  [465 512]
  [465 512]]]

4x4*3x4x2 matmul:
 [[[232 152]
  [172 116]
  [442 296]
  [962 652]]

 [[125 112]
  [123  76]
  [228 226]
  [465 512]]

 [[125 112]
  [123  76]
  [228 226]
  [465 512]]]

在数学中,我认为 numpy 中的更有意义

(a,b)_{i,j,k,a,b,c} = 公式

因为当 a 和 b 是向量时它给出点积,或者当 a 和 b 是矩阵时给出矩阵乘法


对于numpy中的matmul操作,它由结果的部分组成,可以定义为

matmul (a,b)_{i,j,k,c} = 公式


因此,您可以看到matmul(a,b)返回一个形状较小的数组,它具有较小的内存消耗并且在应用程序中更有意义。 特别是结合 广播,可以得到

matmul (a,b)_{i,j,k,l} = 公式

例如。


从以上两个定义中,可以看出使用这两个操作的要求。 假设a.shape=(s1,s2,s3,s4)b.shape=(t1,t2,t3,t4)

  • 要使用dot(a,b)你需要
  1. t3=s4 ;
  • 要使用matmul(a,b)你需要
  1. t3=s4
  2. t2=s2 ,或 t2 和 s2 之一为 1
  3. t1=s1 ,或 t1 和 s1 之一为 1

使用以下代码来说服自己。

import numpy as np
for it in xrange(10000):
    a = np.random.rand(5,6,2,4)
    b = np.random.rand(6,4,3)
    c = np.matmul(a,b)
    d = np.dot(a,b)
    #print 'c shape: ', c.shape,'d shape:', d.shape
    
    for i in range(5):
        for j in range(6):
            for k in range(2):
                for l in range(3):
                    if not c[i,j,k,l] == d[i,j,k,j,l]:
                        print it,i,j,k,l,c[i,j,k,l]==d[i,j,k,j,l]  # you will not see them              

这是与np.einsum的比较,以显示如何预测指数

np.allclose(np.einsum('ijk,ijk->ijk', a,b), a*b)        # True 
np.allclose(np.einsum('ijk,ikl->ijl', a,b), a@b)        # True
np.allclose(np.einsum('ijk,lkm->ijlm',a,b), a.dot(b))   # True

我对 MATMUL 和 DOT 的体验

尝试使用 MATMUL 时,我不断收到“ValueError:传递值的形状为 (200, 1),索引暗示 (200, 3)”。 我想要一个快速的解决方法,并发现 DOT 可以提供相同的功能。 使用 DOT 没有任何错误。 我得到正确答案

带 MATMUL

X.shape
>>>(200, 3)

type(X)

>>>pandas.core.frame.DataFrame

w

>>>array([0.37454012, 0.95071431, 0.73199394])

YY = np.matmul(X,w)

>>>  ValueError: Shape of passed values is (200, 1), indices imply (200, 3)"

带点

YY = np.dot(X,w)
# no error message
YY
>>>array([ 2.59206877,  1.06842193,  2.18533396,  2.11366346,  0.28505879, …

YY.shape

>>> (200, )

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM