簡體   English   中英

在不知道批量大小的情況下進行三維批量矩陣乘法

[英]3-D batch matrix multiplication without knowing batch size

我正在編寫一個張量流程序,需要將一批2-D張量(形狀[None,...]的3-D張量)與2-D矩陣W相乘。 這需要將W轉換為3-D矩陣,這需要知道批量大小。

我無法做到這一點; tf.batch_matmul不再可用, x.get_shape().as_list()[0]返回None ,這對於重新整形/平鋪操作無效。 有什么建議? 我見過有些人使用config.cfg.batch_size ,但我不知道那是什么。

解決方案是使用tf.shape在運行時返回形狀)和tf.tile (接受動態形狀)的組合。

x = tf.placeholder(shape=[None, 2, 3], dtype=tf.float32)
W = tf.Variable(initial_value=np.ones([3, 4]), dtype=tf.float32)
print(x.shape)                # Dynamic shape: (?, 2, 3)

batch_size = tf.shape(x)[0]   # A tensor that gets the batch size at runtime
W_expand = tf.expand_dims(W, axis=0)
W_tile = tf.tile(W_expand, multiples=[batch_size, 1, 1])
result = tf.matmul(x, W_tile) # Can multiply now!

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  feed_dict = {x: np.ones([10, 2, 3])}
  print(sess.run(batch_size, feed_dict=feed_dict))    # 10
  print(sess.run(result, feed_dict=feed_dict).shape)  # (10, 2, 4)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM