
How do I multiply matrices in PyTorch?

With numpy, I can do a simple matrix multiplication like this:

a = numpy.ones((3, 2))
b = numpy.ones((2, 1))
result = a.dot(b)

However, this does not work with PyTorch:

a = torch.ones((3, 2))
b = torch.ones((2, 1))
result = torch.dot(a, b)

This code throws the following error:

RuntimeError: 1D tensors expected, but got 2D and 2D tensors

How do I perform matrix multiplication in PyTorch?

Use torch.mm:

torch.mm(a, b)

torch.dot() behaves differently to np.dot(). There's been some discussion about what would be desirable here. Specifically, torch.dot() treats both a and b as 1D vectors (irrespective of their original shape) and computes their inner product. The error is thrown because this behaviour makes your a a vector of length 6 and your b a vector of length 2; hence their inner product can't be computed. For matrix multiplication in PyTorch, use torch.mm(). NumPy's np.dot(), in contrast, is more flexible: it computes the inner product for 1D arrays and performs matrix multiplication for 2D arrays.
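A minimal sketch of the difference described above (in recent PyTorch versions, torch.dot rejects 2D inputs outright rather than flattening them, so the exact behaviour may vary by version):

```python
import torch

# torch.dot only accepts 1D tensors; passing the 2D tensors from the
# question raises the RuntimeError quoted above.
a = torch.ones(3, 2)
b = torch.ones(2, 1)
try:
    torch.dot(a, b)
except RuntimeError as e:
    print("torch.dot failed:", e)

# On genuine 1D tensors, torch.dot computes the inner product,
# just like np.dot on 1D arrays:
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([4.0, 5.0, 6.0])
print(torch.dot(x, y))  # 1*4 + 2*5 + 3*6 = 32

# torch.mm is the 2D matrix multiplication the question wants:
print(torch.mm(a, b).shape)  # torch.Size([3, 1])
```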

torch.matmul performs matrix multiplication if both arguments are 2D and computes their dot product if both arguments are 1D. For inputs of these dimensions, its behaviour is the same as np.dot. It also lets you do broadcasting or matrix x matrix, matrix x vector and vector x vector operations in batches.

# 1D inputs, same as torch.dot
a = torch.rand(n)
b = torch.rand(n)
torch.matmul(a, b) # torch.Size([])

# 2D inputs, same as torch.mm
a = torch.rand(m, k)
b = torch.rand(k, j)
torch.matmul(a, b) # torch.Size([m, j])

To perform a matrix (rank-2 tensor) multiplication, use any of the following equivalent ways:

AB = A.mm(B)

AB = torch.mm(A, B)

AB = torch.matmul(A, B)

AB = A @ B  # Python 3.5+ only
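A quick sanity check (with illustrative random inputs) that the four spellings above agree:

```python
import torch

A = torch.rand(3, 2)
B = torch.rand(2, 4)

AB = A.mm(B)
# All four forms produce the same (3, 4) result
assert torch.allclose(AB, torch.mm(A, B))
assert torch.allclose(AB, torch.matmul(A, B))
assert torch.allclose(AB, A @ B)
print(AB.shape)  # torch.Size([3, 4])
```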

There are a few subtleties. From the PyTorch documentation:

torch.mm does not broadcast. For broadcasting matrix products, see torch.matmul().

For instance, you cannot multiply two 1-dimensional vectors with torch.mm, nor multiply batched matrices (rank 3). To this end, you should use the more versatile torch.matmul. For an extensive list of the broadcasting behaviours of torch.matmul, see the documentation.
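A small sketch of that limitation: torch.mm rejects batched (rank-3) inputs, while torch.matmul (or torch.bmm for strictly 3D inputs) handles them.

```python
import torch

a = torch.rand(8, 3, 4)  # a batch of 8 matrices
b = torch.rand(8, 4, 5)

# torch.mm only accepts 2D tensors, so this raises a RuntimeError
try:
    torch.mm(a, b)
except RuntimeError as e:
    print("torch.mm failed:", e)

# torch.matmul and torch.bmm both multiply the batch pairwise
print(torch.matmul(a, b).shape)  # torch.Size([8, 3, 5])
print(torch.bmm(a, b).shape)     # torch.Size([8, 3, 5])
```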

For element-wise multiplication, you can simply do the following (if A and B have the same shape):

A * B  # element-wise matrix multiplication (Hadamard product)
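For concreteness, a small worked example of the Hadamard product (torch.mul is the functional equivalent of the * operator):

```python
import torch

A = torch.tensor([[1., 2.], [3., 4.]])
B = torch.tensor([[10., 20.], [30., 40.]])

# Element-wise (Hadamard) product: [[1*10, 2*20], [3*30, 4*40]]
C = A * B
print(C)  # [[10, 40], [90, 160]]

# torch.mul(A, B) gives the same result as A * B
assert torch.equal(C, torch.mul(A, B))
```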

Use torch.mm(a, b) or torch.matmul(a, b).
For 2D tensors, both give the same result.

>>> torch.mm
<built-in method mm of type object at 0x11712a870>
>>> torch.matmul
<built-in method matmul of type object at 0x11712a870>

There's one more option that may be good to know: the @ operator. @Simon H.

>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> a@b
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.mm(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.matmul(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])    

All three give the same result.

Related links:
Matrix multiplication operator
PEP 465 -- A dedicated infix operator for matrix multiplication

You can use the @ operator to compute the matrix product of two tensors in PyTorch (note that for 2D tensors this is matrix multiplication, not an element-wise or 1D dot product).

a = torch.tensor([[1, 2],
                  [3, 4]])
b = torch.tensor([[5, 6],
                  [7, 8]])
c = a @ b  # matrix multiplication
c

d = a * b  # element-wise multiplication
d
