
Batch matrix multiplication in numpy

I have two numpy arrays a and b of shape [5, 5, 5] and [5, 5], respectively. For both a and b, the first entry in the shape is the batch size. When I perform the matrix multiplication operation, I get an array of shape [5, 5, 5]. An MWE is as follows.

import numpy as np

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = a @ b
# c.shape is (5, 5, 5)

Suppose I were instead to run a loop over the batch size, computing a[i] @ b[i].T for each i; since b[i] is one-dimensional (so .T is a no-op), each product has shape (5,). Finally, if I concatenate all the results along axis 1 (after adding a trailing axis to each), I get a resultant array of shape [5, 5]. The code below better describes these lines.

a = np.ones((5, 5, 5))
b = np.random.randint(0, 10, (5, 5))
c = []
for i in range(5):
    c.append(a[i] @ b[i].T)
c = np.concatenate([d[:, None] for d in c], axis=1).T
# c.shape evaluates to be (5, 5)

Can I get the above functionality without using a loop? For example, PyTorch provides a function called torch.bmm to compute this. Thanks.

Add an extra dimension to b to make the matrix multiplication batch-compatible, and remove the redundant last dimension at the end by squeezing:

c = np.matmul(a, b[:, :, None]).squeeze(-1)

Or equivalently:

c = (a @ b[:, :, None]).squeeze(-1)

Both make the matrix multiplication of a and b conformable by reshaping b to 5 x 5 x 1 in your example.
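As a quick sanity check (a minimal sketch with random inputs), the vectorized result matches an explicit loop over the batch dimension:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((5, 5, 5))
b = rng.standard_normal((5, 5))

# Vectorized: treat each b[i] as a column vector, then drop the trailing axis.
c_vec = (a @ b[:, :, None]).squeeze(-1)            # shape (5, 5)

# Reference: explicit loop over the batch dimension.
c_loop = np.stack([a[i] @ b[i] for i in range(5)])  # shape (5, 5)

assert np.allclose(c_vec, c_loop)
```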

You can work this out using numpy einsum.

c = np.einsum('BNi,Bi->BN', a, b)

PyTorch also provides an einsum function with slightly different syntax, so you can easily port this. It handles other shapes as well.

Then you don't have to worry about transpose or squeeze operations. It can also save memory, because no copy of the existing matrices needs to be created internally.
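A short sketch showing that the same subscript string works for non-square, non-matching batch shapes (the shapes here are illustrative, not from the question): for each batch index B, einsum contracts over i, i.e. a batched matrix-vector product.

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.standard_normal((3, 4, 6))   # batch of 3 matrices, each 4 x 6
b = rng.standard_normal((3, 6))      # batch of 3 vectors, each length 6

# 'BNi,Bi->BN': for each batch B, sum over i -> one length-4 vector per batch.
c = np.einsum('BNi,Bi->BN', a, b)    # shape (3, 4)

# Same result via broadcasting matmul.
assert np.allclose(c, (a @ b[:, :, None]).squeeze(-1))
```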
