简体   繁体   English

为什么tf.matmul(a,b,transpose_b = True)有效,但不是tf.matmul(a,tf.transpose(b))?

[英]Why does tf.matmul(a,b, transpose_b=True) work, but not tf.matmul(a, tf.transpose(b))?

Code: 码:

x = tf.constant([1.,2.,3.], shape = (3,2,4))
y = tf.constant([1.,2.,3.], shape = (3,21,4))
tf.matmul(x,y)                     # Doesn't work. 
tf.matmul(x,y,transpose_b = True)  # This works. Shape is (3,2,21)
tf.matmul(x,tf.transpose(y))       # Doesn't work.

I want to know what shape y becomes inside tf.matmul(x,y,transpose_b = True) so I can work out what is really going on inside an LSTM with attention. 我想知道ytf.matmul(x,y,transpose_b = True)里面变成了什么样的形状tf.matmul(x,y,transpose_b = True)所以我可以注意到LSTM里面真正发生的事情。

Transpose can be defined differently for tensors of rank > 2, and here the difference is in axes that are transposed by tf.transpose and tf.matmul(..., transpose_b=True) . 对于秩> 2的张量,可以不同地定义转置,并且这里的差异在于由tf.transposetf.matmul(..., transpose_b=True)转置的轴。

By default, tf.transpose does this: 默认情况下, tf.transpose执行此操作:

The returned tensor's dimension i will correspond to the input dimension perm[i] . 返回的张量的维度i将对应于输入维度perm[i] If perm is not given, it is set to (n-1...0) , where n is the rank of the input tensor. 如果没有给出perm,则将其设置为(n-1...0) ,其中n是输入张量的等级。 Hence by default, this operation performs a regular matrix transpose on 2-D input Tensors. 因此,默认情况下,此操作在2-D输入张量上执行常规矩阵转置。

So in your case, it's going to transform y into a tensor of shape (4, 21, 3) , which is not compatible with x (see below). 所以在你的情况下,它会将y转换为一个形状的张量(4, 21, 3) 4,21,3 (4, 21, 3) ,这与x 不兼容 (见下文)。

But if you set perm=[0, 2, 1] , the result is compatible : 但是如果设置perm=[0, 2, 1] ,结果是兼容的

# Works! (3, 2, 4) * (3, 4, 21) -> (3, 2, 21).
tf.matmul(x, tf.transpose(y, [0, 2, 1]))

About tf.matmul 关于tf.matmul

You can compute the dot product: (a, b, c) * (a, c, d) -> (a, b, d) . 您可以计算点积: (a, b, c) * (a, c, d) -> (a, b, d) But it's not tensor dot product -- it's a batch operation (see this question ). 但它不是张量点产品 - 它是批量操作 (见这个问题 )。

In this case, a is considered a batch size, so tf.matmul computes a dot-products of matrices (b, c) * (c, d) . 在这种情况下, a被认为是批量大小,因此tf.matmul计算矩阵(b, c) * (c, d) a点积。

Batch can be more than one dimension, so this is also valid: 批处理可以是多个维度,因此这也是有效的:

(a, b, c, d) * (a, b, d, e) -> (a, b, c, e)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM