簡體   English   中英

如何在張量流中用 3d 張量對 2d 張量進行 matmul?

[英]How to matmul a 2d tensor with a 3d tensor in tensorflow?

numpy您可以將 2d 數組與 3d 數組相乘,如下例所示:

>>> X = np.random.randn(3,5,4) # [3,5,4]
... W = np.random.randn(5,5) # [5,5]
... out = np.matmul(W, X) # [3,5,4]

根據我的理解, np.matmul()接受W並沿X的第一維廣播它。 但在tensorflow是不允許的:

>>> _X = tf.constant(X)
... _W = tf.constant(W)
... _out = tf.matmul(_W, _X)

ValueError: Shape must be rank 2 but is rank 3 for 'MatMul_1' (op: 'MatMul') with input shapes: [5,5], [3,5,4].

那么上面的tensorflow np.matmul()是否有等價物? 張量tensorflow 2d 張量與 3d 張量相乘的最佳實踐是什么?

嘗試使用tf.tile在乘法之前匹配矩陣的維度。 numpy 的自動廣播功能似乎沒有在 tensorflow 中實現。 你必須手動完成。

W_T = tf.tile(tf.expand_dims(W,0),[3,1,1])

這應該可以解決問題

import numpy as np
import tensorflow as tf

X = np.random.randn(3,4,5)
W = np.random.randn(5,5)

_X = tf.constant(X)
_W = tf.constant(W)
_W_t = tf.tile(tf.expand_dims(_W,0),[3,1,1])

with tf.Session() as sess:
    print(sess.run(tf.matmul(_X,_W_t)))

您可以改用tensordot

tf.transpose(tf.tensordot(_W, _X, axes=[[1],[1]]),[1,0,2])

以下來自 tensorflow XLA廣播語義

XLA 語言盡可能嚴格和明確,避免隱含和“神奇”的特征。 這些特性可能會使一些計算更容易定義,但代價是在用戶代碼中加入了更多的假設,這些假設在長期內很難改變。

所以 Tensorflow 不提供內置的廣播功能。

然而,它確實提供了一些可以像廣播一樣重塑張量的東西。 這個操作叫做tf.tile

簽名如下:

tf.tile(input, multiples, name=None)

此操作通過多次復制輸入來創建新的張量。 輸出張量的第 i 個維度具有 input.dims(i) * multiples[i] 元素,並且輸入的值沿第 i 個維度復制 multiples[i] 次。

您還可以使用tf.einsum來避免平鋪張量:

tf.einsum("ab,ibc->iac", _W, _X)

一個完整的例子:

import numpy as np
import tensorflow as tf

# Numpy-style matrix multiplication:
X = np.random.randn(3,5,4)
W = np.random.randn(5,5)
np_WX = np.matmul(W, X)

# TensorFlow-style multiplication:
_X = tf.constant(X)
_W = tf.constant(W)
_WX = tf.einsum("ab,ibc->iac", _W, _X)

with tf.Session() as sess:
    tf_WX = sess.run(_WX)

# Check that the results are the same:
print(np.allclose(np_WX, tf_WX))

在這里,我將使用 keras 后端K.dot和 tensorflow tf.transpose 首先交換 3D 張量的內部暗淡

X=tf.transpose(X,perm=[0,-1,1]) # X shape=[3,4,5]

現在像這樣繁殖

out=K.dot(X,W) # out shape=[3,4,5]

現在再次交換軸

out = tf.transpose(out,perm=[0,-1,1]) # out shape=[3,5,4]

上述解決方案以很少的時間成本節省了內存,因為您沒有平鋪W

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM