簡體   English   中英

Tensorflow:使用 2D 張量替換 3D 張量中的索引

[英]Tensorflow: Replace indices in 3D tensor using 2D tensor

  • 我有一個 3D 張量X ,形狀為[7, 240, 768]
  • 我有另一個形狀為[7, 240]的張量mask_idx ,其中包含0/False1/True ,其中0/False表示我不想更新X[i][j]中的值,而1/True表示我想要這樣做X[i][j] = tf.zeros([768])

我嘗試使用tf.where(mask_idx, tf.zeros([7, 240, 768]), X)但收到此錯誤:

*** tensorflow.python.framework.errors_impl.InvalidArgumentError:條件 [7,240],然后是 [7,240,768],否則 [7,240,768] 必須是可廣播的 [Op:SelectV2]

任何人都可以提出正確的方法嗎?

TF 檢查維度是否可以從右到左廣播,因此一種簡單的方法是在最后一個維度擴展掩碼張量,即使其形狀為(7,240,1)

tf.where(mask_idx[...,None], 0, X)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM