[英]Tensorflow: Replace indices in 3D tensor using 2D tensor
X
,形狀為[7, 240, 768]
。[7, 240]
的張量mask_idx
,其中包含0/False
和1/True
,其中0/False
表示我不想更新X[i][j]
中的值,而1/True
表示我想要這樣做X[i][j] = tf.zeros([768])
。 我嘗試使用tf.where(mask_idx, tf.zeros([7, 240, 768]), X)
但收到此錯誤:
*** tensorflow.python.framework.errors_impl.InvalidArgumentError:條件 [7,240],然后是 [7,240,768],否則 [7,240,768] 必須是可廣播的 [Op:SelectV2]
任何人都可以提出正確的方法嗎?
TF 檢查維度是否可以從右到左廣播,因此一種簡單的方法是在最后一個維度擴展掩碼張量,即使其形狀為(7,240,1)
。
tf.where(mask_idx[...,None], 0, X)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.