簡體   English   中英

Tensorflow Keras 無作為第一維的張量乘法

[英]Tensorflow Keras Tensor Multiplication with None as First Dimension

我正在使用 TensorFlow Keras 后端,並且我有兩個相同形狀的張量ab(None, 4, 7) ,其中None代表批量維度。

我想做矩陣乘法,我期待(None, 4, 4)的結果。
即對於每批,做一個matmul: (4,7)·(7,4) = (4,4)

這是我的代碼——

K.dot(a, K.reshape(b, (-1, 7, 4)))

此代碼給出形狀張量(None, 4, None, 4)

我想知道高維矩陣乘法是如何工作的? 這樣做的正確方法是什么?

IIUC,您可以直接使用tf.matmul作為 model 的一部分並轉置b或顯式地將操作包裝在Lambda層中:

import tensorflow as tf

a = tf.keras.layers.Input((4, 7))
b = tf.keras.layers.Input((4, 7))
output = tf.matmul(a, b, transpose_b=True)
model = tf.keras.Model([a, b], output)
model.summary()
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_15 (InputLayer)          [(None, 4, 7)]       0           []                               
                                                                                                  
 input_16 (InputLayer)          [(None, 4, 7)]       0           []                               
                                                                                                  
 tf.linalg.matmul_2 (TFOpLambda  (None, 4, 4)        0           ['input_15[0][0]',               
 )                                                                'input_16[0][0]']               
                                                                                                  
==================================================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
__________________________________________________________________________________________________

或者

import tensorflow as tf

a = tf.keras.layers.Input((4, 7))
b = tf.keras.layers.Input((4, 7))
output = tf.keras.layers.Lambda(lambda x: tf.matmul(x[0], x[1], transpose_b=True))([a, b])
model = tf.keras.Model([a, b], output)
model.summary()

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM