簡體   English   中英

廣播帶有動態形狀的tf.matmul

[英]Broadcast tf.matmul with dynamic shapes

我想在等級2和3的兩個張量之間廣播tf.matmul操作,其中一個張量包含“未知”形狀的維(基本上是特定維中的“無”值)。

問題在於動態尺寸tf.reshapetf.broadcast_to似乎不起作用。

x = tf.placeholder(shape=[None, 5, 10], dtype=tf.float32)
w = tf.ones([10, 20])
y = x @ w
with tf.Session() as sess:
  r1 = sess.run(y, feed_dict={x: np.ones([3, 5, 10])})
  r2 = sess.run(y, feed_dict={x: np.ones([7, 5, 10])})

以上面的代碼為例。 在這種情況下,我要分別喂食兩批分別為3和7的元素。 我希望r1r2是矩陣w與這些批處理中3或7個元素中的每個元素相乘的結果。 因此, r1r2的最終形狀分別為( r2 )和(7、5、20),但是我得到的是:

ValueError:形狀必須為2級,但輸入形狀為[?,5,10],[10,20]的“ matmul”(操作:“ MatMul”)為3級。

w可以擴展為rank-3張量,其批大小等於輸入的大小。 然后,可以執行matmul操作

x = tf.placeholder(shape=[None, 5, 10], dtype=tf.float32)
w = tf.ones([10, 20])

number_batches = tf.shape(x)[0]
w = tf.tile(tf.expand_dims(w, 0), [number_batches, 1, 1])
y = x @ w
with tf.Session() as sess:
  print(sess.run(y, feed_dict={x: np.ones([2, 5, 10])}))
  print(sess.run(y, feed_dict={x: np.ones([3, 5, 10])}))

現場代碼在這里

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM