
How does torch.einsum perform this 4D tensor multiplication?

I have come across some code which uses torch.einsum to compute a tensor multiplication. I am able to understand the workings for lower order tensors, but not for the 4D tensor below:

import torch

a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))

c = torch.einsum('nxhd,nyhd->nhxy', [a,b])

print(c.size())

# output: torch.Size([3, 2, 5, 4])

I need help regarding:

  1. What is the operation that has been performed here (an explanation of how the matrices were multiplied/transposed, etc.)?
  2. Is torch.einsum actually beneficial in this scenario?

(Skip to the tl;dr section if you just want the breakdown of the steps involved in an einsum)

I'll try to explain how einsum works step by step for this example, but instead of using torch.einsum, I'll be using numpy.einsum (documentation), which does exactly the same thing; I am just, in general, more comfortable with it. Nonetheless, the same steps happen for torch as well.

Let's rewrite the above code in NumPy -

import numpy as np

a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape

#(3, 2, 5, 4)

Step by step np.einsum

Einsum is composed of 3 steps: multiply, sum and transpose.

Let's look at our dimensions. We have a (3, 5, 2, 10) and a (3, 4, 2, 10) that we need to bring to (3, 2, 5, 4) based on 'nxhd,nyhd->nhxy'.
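
In index notation, 'nxhd,nyhd->nhxy' says that d is the only index that appears in the inputs but not in the output, so it is the one that gets summed over: c[n,h,x,y] = sum over d of a[n,x,h,d] * b[n,y,h,d]. As a sanity reference (a naive explicit-loop sketch, not how einsum computes internally), that formula reads as -

import numpy as np

def einsum_ref(a, b):
    # Explicit loops for 'nxhd,nyhd->nhxy':
    # c[n,h,x,y] = sum_d a[n,x,h,d] * b[n,y,h,d]
    N, X, H, D = a.shape
    _, Y, _, _ = b.shape
    c = np.zeros((N, H, X, Y))
    for n in range(N):
        for h in range(H):
            for x in range(X):
                for y in range(Y):
                    c[n, h, x, y] = np.dot(a[n, x, h, :], b[n, y, h, :])
    return c

np.allclose(einsum_ref(a, b), c)

#True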

1. Multiply

Let's not worry about the order in which the n,x,y,h,d axes appear, and just focus on whether we want to keep them or remove (reduce) them. Writing them down as a table, let's see how we can arrange our dimensions -

        ## Multiply ##
       n   x   y   h   d
      --------------------
a  ->  3   5       2   10
b  ->  3       4   2   10
c1 ->  3   5   4   2   10

To get the broadcasting multiplication between the x and y axes to result in (x, y), we will have to add a new axis at the right places and then multiply.

a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)

c1 = a1*b1
c1.shape

#(3, 5, 4, 2, 10)  #<-- (n, x, y, h, d)
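
To convince ourselves the broadcast did what we wanted, we can spot-check one element at arbitrarily chosen indices -

n, x, y, h, d = 1, 2, 3, 0, 7
np.isclose(c1[n, x, y, h, d], a[n, x, h, d] * b[n, y, h, d])

#True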

2. Sum / Reduce

Next, we want to reduce over the last axis (the one of size 10). This will get us the dimensions (n,x,y,h).

          ## Reduce ##
        n   x   y   h   d
       --------------------
c1  ->  3   5   4   2   10
c2  ->  3   5   4   2

This is straightforward. Let's just do np.sum over axis=-1 -

c2 = np.sum(c1, axis=-1)
c2.shape

#(3,5,4,2)  #<-- (n, x, y, h)
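
Put differently, every entry of c2 is just the dot product of a length-10 vector from a with one from b, which is exactly what summing out the shared d index means -

n, x, y, h = 0, 1, 2, 1
np.isclose(c2[n, x, y, h], np.dot(a[n, x, h], b[n, y, h]))

#True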

3. Transpose

The last step is rearranging the axes using a transpose. We can use np.transpose for this. c2.transpose(0,3,1,2) basically brings the 3rd axis (h) right after the 0th axis (n) and pushes the 1st and 2nd axes (x and y) back. So, (n,x,y,h) becomes (n,h,x,y).

c3 = c2.transpose(0,3,1,2)
c3.shape

#(3,2,5,4)  #<-- (n, h, x, y)
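
If the permutation tuple is hard to read, np.moveaxis does the same job here more declaratively: move axis 3 (h) to position 1 -

c3_alt = np.moveaxis(c2, 3, 1)
c3_alt.shape, np.allclose(c3, c3_alt)

#((3, 2, 5, 4), True)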

4. Final check

Let's do a final check and see if c3 is the same as the c that was generated from np.einsum -

np.allclose(c,c3)

#True
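
And since torch.einsum follows the same semantics, the manual result also matches torch on the same data (a quick cross-check, assuming torch is installed) -

import torch

tc = torch.einsum('nxhd,nyhd->nhxy', torch.from_numpy(a), torch.from_numpy(b))
np.allclose(c3, tc.numpy())

#True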

TL;DR

Thus, we have implemented 'nxhd, nyhd -> nhxy' as -

input     -> nxhd, nyhd
multiply  -> nxyhd      #broadcasting
sum       -> nxyh       #reduce
transpose -> nhxy
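
Or, squashed into a single (equivalent, if less readable) NumPy one-liner -

c_manual = (a[:,:,None,:,:] * b[:,None,:,:,:]).sum(axis=-1).transpose(0,3,1,2)
np.allclose(c_manual, c)

#True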

Advantage

The advantage of np.einsum over the multiple separate steps is that it can choose the "path" it takes to do the computation and perform multiple operations with the same function. This can be done with the optimize parameter, which will optimize the contraction order of an einsum expression.
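
A minimal sketch of what that looks like: np.einsum_path returns the contraction path plus a printable summary, and the path can be fed back into np.einsum through optimize -

path, info = np.einsum_path('nxhd,nyhd->nhxy', a, b, optimize='optimal')
print(info)  # human-readable summary of the chosen contraction order

c_opt = np.einsum('nxhd,nyhd->nhxy', a, b, optimize=path)
np.allclose(c, c_opt)

#True

With only two operands there is just one possible contraction, so the path is trivial here; the optimization really pays off when an expression chains three or more operands.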

A non-exhaustive list of these operations, which can be computed by einsum, is shown below, with example subscript strings right after the list:

  • Trace of an array, numpy.trace.
  • Return a diagonal, numpy.diag.
  • Array axis summations, numpy.sum.
  • Transpositions and permutations, numpy.transpose.
  • Matrix multiplication and dot product, numpy.matmul, numpy.dot.
  • Vector inner and outer products, numpy.inner, numpy.outer.
  • Broadcasting, element-wise and scalar multiplication, numpy.multiply.
  • Tensor contractions, numpy.tensordot.
  • Chained array operations, inefficient calculation order, numpy.einsum_path.
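
For instance, with small arbitrary arrays A, B (square matrices), u, v (vectors), and T1, T2 (3D tensors), each einsum line below mirrors the function named in its comment -

A = np.random.random((4, 4))
B = np.random.random((4, 4))
u = np.random.random(4)
v = np.random.random(4)
T1 = np.random.random((2, 3, 4))
T2 = np.random.random((3, 4, 5))

np.einsum('ii', A)                 # trace:              np.trace(A)
np.einsum('ii->i', A)              # diagonal:           np.diag(A)
np.einsum('ij->', A)               # sum over all axes:  np.sum(A)
np.einsum('ij->ji', A)             # transpose:          np.transpose(A)
np.einsum('ik,kj->ij', A, B)       # matrix multiply:    np.matmul(A, B)
np.einsum('i,i->', u, v)           # inner product:      np.inner(u, v)
np.einsum('i,j->ij', u, v)         # outer product:      np.outer(u, v)
np.einsum('ij,ij->ij', A, B)       # element-wise:       np.multiply(A, B)
np.einsum('ijk,jkl->il', T1, T2)   # tensor contraction: np.tensordot(T1, T2, axes=2)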

Benchmarks

%%timeit
np.einsum('nxhd,nyhd->nhxy', a, b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)

This shows that np.einsum performs the operation faster than the individual steps do.
