How does torch.einsum perform this 4D tensor multiplication?
I have come across code which uses torch.einsum to compute a tensor multiplication. I am able to understand how it works for lower-order tensors, but not for the 4D tensors below:
import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])
I need help regarding:
Is torch.einsum actually beneficial in this scenario?

(Skip to the tl;dr section if you just want the breakdown of steps involved in an einsum.)
I'll try to explain how einsum works step by step for this example, but instead of torch.einsum I'll be using numpy.einsum (documentation), which does exactly the same thing; I am just, in general, more comfortable with it. Nonetheless, the same steps happen for torch as well.
Let's rewrite the above code in NumPy:
import numpy as np
a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape
#(3, 2, 5, 4)
Einsum is composed of 3 steps: multiply, sum and transpose.
Let's look at our dimensions. We have a (3, 5, 2, 10) array and a (3, 4, 2, 10) array that we need to bring to (3, 2, 5, 4) based on 'nxhd,nyhd->nhxy'.
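Before breaking the operation into steps, it may help to see what the subscript string asks for when written as plain loops. The following is only a minimal reference sketch (the names c_loops and c_einsum, and the seeded random generator, are illustrative choices, not part of the original code): every output entry c[n, h, x, y] is the sum over d of a[n, x, h, d] * b[n, y, h, d].

```python
import numpy as np

rng = np.random.default_rng(0)          # seeded only to make the sketch repeatable
a = rng.random((3, 5, 2, 10))           # axes: n, x, h, d
b = rng.random((3, 4, 2, 10))           # axes: n, y, h, d

N, X, H, D = a.shape
Y = b.shape[1]

# c[n, h, x, y] = sum over d of a[n, x, h, d] * b[n, y, h, d]
c_loops = np.zeros((N, H, X, Y))
for n in range(N):
    for h in range(H):
        for x in range(X):
            for y in range(Y):
                for d in range(D):
                    c_loops[n, h, x, y] += a[n, x, h, d] * b[n, y, h, d]

c_einsum = np.einsum('nxhd,nyhd->nhxy', a, b)
print(c_loops.shape)                    # (3, 2, 5, 4)
print(np.allclose(c_loops, c_einsum))   # True
```

The index d appears in both inputs but not in the output, so it is summed away; n and h appear in both inputs and in the output, so they are simply matched up; x and y each appear in only one input, so every x is paired with every y.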
Let's not worry about the order of the n,x,y,h,d axes for now, and only think about whether you want to keep each of them or remove (reduce) it. Writing them down as a table shows how we can arrange our dimensions:
          ## Multiply ##
       n    x    y    h    d
      ------------------------
a  ->  3    5         2    10
b  ->  3         4    2    10
c1 ->  3    5    4    2    10
To get the broadcasting multiplication between the x and y axes to result in (x, y), we have to add a new axis in the right place in each array and then multiply.
a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)
c1 = a1*b1
c1.shape
#(3, 5, 4, 2, 10) #<-- (n, x, y, h, d)
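To make the broadcasting step a little more concrete, here is a small sketch that checks a single entry of the product. It uses np.expand_dims, which should be equivalent to indexing with None; the index values picked below are arbitrary and only for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((3, 5, 2, 10))    # axes: n, x, h, d
b = rng.random((3, 4, 2, 10))    # axes: n, y, h, d

a1 = np.expand_dims(a, axis=2)   # same as a[:, :, None, :, :] -> (3, 5, 1, 2, 10)
b1 = np.expand_dims(b, axis=1)   # same as b[:, None, :, :, :] -> (3, 1, 4, 2, 10)
c1 = a1 * b1                     # size-1 axes are stretched    -> (3, 5, 4, 2, 10)

# Arbitrary indices, chosen only for illustration.
n, x, y, h, d = 1, 3, 2, 0, 7
entry_matches = np.isclose(c1[n, x, y, h, d], a[n, x, h, d] * b[n, y, h, d])

print(a1.shape, b1.shape, c1.shape)
print(entry_matches)             # True
```

Each entry of c1 is just one product of an element of a with an element of b; nothing has been summed yet.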
Next, we want to reduce the last axis d, which has size 10. This will get us the dimensions (n,x,y,h).
          ## Reduce ##
       n    x    y    h    d
      ------------------------
c1 ->  3    5    4    2    10
c2 ->  3    5    4    2
This is straightforward. Let's just do np.sum over axis=-1:
c2 = np.sum(c1, axis=-1)
c2.shape
#(3,5,4,2) #<-- (n, x, y, h)
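One way to read the result of the reduction: after summing over d, each entry of c2 should be the dot product of one length-10 vector from a with one length-10 vector from b. A minimal check, with arbitrarily chosen indices:

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((3, 5, 2, 10))    # axes: n, x, h, d
b = rng.random((3, 4, 2, 10))    # axes: n, y, h, d

c1 = a[:, :, None, :, :] * b[:, None, :, :, :]   # (n, x, y, h, d)
c2 = c1.sum(axis=-1)                             # (n, x, y, h)

# Arbitrary indices, chosen only for illustration.
n, x, y, h = 2, 4, 1, 1
dot = np.dot(a[n, x, h, :], b[n, y, h, :])       # dot product over the d axis

print(c2.shape)                        # (3, 5, 4, 2)
print(np.isclose(c2[n, x, y, h], dot)) # True
```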
The last step is rearranging the axes using a transpose. We can use np.transpose for this. Calling c2.transpose(0,3,1,2) brings the 3rd axis directly after the 0th axis and pushes the 1st and 2nd axes back. So, (n,x,y,h) becomes (n,h,x,y).
c3 = c2.transpose(0,3,1,2)
c3.shape
#(3,2,5,4) #<-- (n, h, x, y)
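As a small aside, the same rearrangement can be written in a few equivalent ways. The sketch below uses a random stand-in for c2 (not the actual result above) and compares the method call with the function form np.transpose and with np.moveaxis; it also checks that the transpose is a view rather than a copy.

```python
import numpy as np

rng = np.random.default_rng(0)
c2 = rng.random((3, 5, 4, 2))              # stand-in with axes (n, x, y, h)

c3_method = c2.transpose(0, 3, 1, 2)       # method form
c3_func = np.transpose(c2, (0, 3, 1, 2))   # function form takes the axes as a tuple
c3_move = np.moveaxis(c2, 3, 1)            # move axis 3 to position 1, keep the rest in order

print(c3_method.shape)                     # (3, 2, 5, 4)
print(np.array_equal(c3_method, c3_func))  # True
print(np.array_equal(c3_method, c3_move))  # True
print(np.shares_memory(c2, c3_method))     # True: no data is copied

# Element (n, h, x, y) of the result is element (n, x, y, h) of the input.
print(c3_method[1, 0, 4, 2] == c2[1, 4, 2, 0])  # True
```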
Let's do a final check and see if c3 is the same as the c which was generated by np.einsum:
np.allclose(c,c3)
#True
Thus, we have implemented 'nxhd, nyhd -> nhxy' as:
input -> nxhd, nyhd
multiply -> nxyhd #broadcasting
sum -> nxyh #reduce
transpose -> nhxy
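The multiply, sum, transpose decomposition is a way of understanding the operation, not necessarily how one would compute it by hand. As an alternative sketch (a swapped-in technique, not something from the steps above), the same result can be obtained with a batched matrix multiplication: move h next to n in both inputs, then multiply (x, d) matrices by (d, y) matrices for every (n, h) pair. This avoids building the 5D intermediate.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((3, 5, 2, 10))        # axes: n, x, h, d
b = rng.random((3, 4, 2, 10))        # axes: n, y, h, d

a_t = a.transpose(0, 2, 1, 3)        # (n, h, x, d) -> (3, 2, 5, 10)
b_t = b.transpose(0, 2, 3, 1)        # (n, h, d, y) -> (3, 2, 10, 4)

# matmul treats the leading (n, h) axes as batch axes
# and multiplies the trailing (x, d) @ (d, y) matrices.
c_matmul = a_t @ b_t                 # (n, h, x, y)

c_einsum = np.einsum('nxhd,nyhd->nhxy', a, b)
print(c_matmul.shape)                     # (3, 2, 5, 4)
print(np.allclose(c_matmul, c_einsum))    # True
```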
The advantage of np.einsum over the multiple steps taken above is that you can choose the "path" it takes to do the computation, and perform multiple operations with the same function. This is controlled by the optimize parameter, which optimizes the contraction order of an einsum expression.
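The contraction order matters most when more than two operands are involved. Here is a minimal sketch, assuming a made-up three-matrix chain (the shapes and names A, B, C are illustrative only), that asks np.einsum_path for a contraction order and then passes that path back to np.einsum through optimize.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical operands, chosen so that the contraction order makes a difference.
A = rng.random((8, 64))
B = rng.random((64, 64))
C = rng.random((64, 8))

# einsum_path returns the contraction path and a human-readable report.
path, report = np.einsum_path('ij,jk,kl->il', A, B, C, optimize='greedy')
print(path)      # a list starting with 'einsum_path', followed by index pairs
print(report)    # includes the estimated cost of the naive and optimized orders

default = np.einsum('ij,jk,kl->il', A, B, C)
optimized = np.einsum('ij,jk,kl->il', A, B, C, optimize=path)
print(np.allclose(default, optimized))   # True: same values, different evaluation order
```

For the two-operand expression in this question there is only one pairwise contraction, so the choice of path is unlikely to change much; the parameter is more relevant for longer chains.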
A non-exhaustive list of the operations that can be computed by einsum is shown below:
numpy.trace
numpy.diag
numpy.sum
numpy.transpose
numpy.matmul, numpy.dot
numpy.inner, numpy.outer
numpy.multiply
numpy.tensordot
numpy.einsum_path
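To illustrate the list, here is a small sketch pairing each function with one einsum expression that should produce the same result. The arrays and their shapes are made up for the example, and the subscript strings are one possible spelling among several.

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.random((4, 4))
P = rng.random((4, 3))
u = rng.random(4)
v = rng.random(4)
T1 = rng.random((2, 3, 4))
T2 = rng.random((3, 4, 5))

checks = {
    'trace':     np.allclose(np.einsum('ii', M), np.trace(M)),
    'diag':      np.allclose(np.einsum('ii->i', M), np.diag(M)),
    'sum':       np.allclose(np.einsum('ij->j', M), np.sum(M, axis=0)),
    'transpose': np.allclose(np.einsum('ij->ji', P), np.transpose(P)),
    'matmul':    np.allclose(np.einsum('ij,jk->ik', M, P), np.matmul(M, P)),
    'inner':     np.allclose(np.einsum('i,i', u, v), np.inner(u, v)),
    'outer':     np.allclose(np.einsum('i,j->ij', u, v), np.outer(u, v)),
    'multiply':  np.allclose(np.einsum('ij,ij->ij', M, M), np.multiply(M, M)),
    # axes=2 contracts the last two axes of T1 with the first two axes of T2
    'tensordot': np.allclose(np.einsum('ijk,jkl->il', T1, T2),
                             np.tensordot(T1, T2, axes=2)),
}

for name, ok in checks.items():
    print(f"{name:10s} {ok}")
```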
%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)
This shows that, for these inputs, np.einsum performs the operation faster than the individual steps.
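The %%timeit magic only works inside IPython or Jupyter. Outside a notebook, a rough equivalent can be sketched with the standard timeit module, as below. The helper names with_einsum and with_steps are made up for this example, and the numbers will differ between machines, NumPy versions and array sizes, so treat any single measurement as indicative only.

```python
import timeit

import numpy as np

rng = np.random.default_rng(0)
a = rng.random((3, 5, 2, 10))
b = rng.random((3, 4, 2, 10))


def with_einsum():
    return np.einsum('nxhd,nyhd->nhxy', a, b)


def with_steps():
    # multiply (broadcast), sum over d, then transpose
    return np.sum(a[:, :, None, :, :] * b[:, None, :, :, :], axis=-1).transpose(0, 3, 1, 2)


same_result = np.allclose(with_einsum(), with_steps())

number = 1000  # calls per measurement
t_einsum = timeit.timeit(with_einsum, number=number) / number
t_steps = timeit.timeit(with_steps, number=number) / number

print(same_result)
print(f"einsum: {t_einsum * 1e6:.2f} µs per call")
print(f"steps:  {t_steps * 1e6:.2f} µs per call")
```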