簡體   English   中英

為什么 tf.matmul 不能與轉置張量一起使用?

[英]Why doesn't `tf.matmul` work with transposed tensor?

為什么tf.matmul用於轉置張量?

transpose_b=True是可以的,但不是tf.transpose(inp)

此屏幕截圖是在 Colab 中使用tensorflow-gpu==2.0.0-rc1制作的:

在此處輸入圖像描述

tf.linalg.matmul中的transpose_b=True僅轉置第二個給定張量的最后兩個軸,而tf.transpose沒有更多的 arguments 完全反轉維度。 相當於:

inp_t = tf.transpose(inp, (0, 2, 1))
tf.matmul(inp, inp_t)

如果您沒有明確指定perm (permutation) 參數, tf.transpose()默認執行常規二維矩陣轉置(它將 perm 參數設置為 input_tensor_rank-1)。 所以適當設置perm參數

inp_t = tf.transpose(inp, perm=[0,2,1])
y = tf.matmul(inp, x)
print(y)

Tensorflow 告訴您的是,將兩個張量相乘時尺寸不匹配。 用基本的線性代數術語來考慮它。 矩陣相乘時,只能將矩陣相乘,其中第一個矩陣的最后一維與第二個矩陣的第一維相同。 例如,您可以將 2x4 矩陣與 4x2 矩陣相乘(這是transpose為您所做的。來自文檔

如果未給出perm ,則將其設置為 (n-1...0),其中 n 是輸入張量的秩。 因此,默認情況下,此操作對二維輸入張量執行常規矩陣轉置。

因此,如果您在更高維度上省略 perm, tf.transform()會像 2d 張量(矩陣)一樣切換維度:

inp_t_without_perm = tf.transpose(inp)
inp_t_without_perm
# Output: <tf.Tensor 'transpose_8:0' shape=(1, 4, 2) dtype=float32>

所以它只是切換第一個維度的最后一個維度,而第二個維度保持不變。 這相當於:

inp_t_with_wrong_perm = tf.transpose(inp, perm=[2,1,0])
inp_t_with_wrong_perm
# Output: <tf.Tensor 'transpose_8:0' shape=(1, 4, 2) dtype=float32>

如果你這樣做:

mul = tf.matmul(inp, inp_t_without_perm) # or with inp_t_with_wrong_perm

您收到此錯誤,因為您的前兩個或后兩個維度不匹配。

現在,當將高階張量相乘時,您必須以與 2d 中相同的方式對齊不同的維度(將其視為將張量划分為矩陣和向量。在您的情況下,您有一個向量和一個矩陣...對不起,我還沒想出更好的比喻,當我找到一個安靜的半小時用筆和紙,我可以用愛因斯坦符號使它更正式,但這基本上是它的工作原理...)。

對於您的情況,有效的是:

inp = tf.reshape(tf.linspace(-1.0, 1.0, 8), (2,4,1))
# switch the last two dimensions so you can multiply 4x1 by 1x4
# and leave first dimension as it is.
inp_t = tf.transpose(inp, perm=[0,2,1])
mul = tf.matmul(inp, inp_t)
mul
# Output: <tf.Tensor 'MatMul_8:0' shape=(2, 4, 4) dtype=float32>

請注意,在您的情況下,這是唯一有效的排列,因為這種乘法是不可交換的。 所以你必須從左到右匹配尺寸(再次,抱歉揮手,但正式的數學證明需要我做一些高階張量代數,但我認為這正是你想要實現的...)。 我沒有 go 太深入到文檔中,但我認為transform_b參數正在為你做這個排列。 希望有幫助。 請評論更多問題。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM