[英]Indexing k-th dimension of tensor with another tensor in Tensorflow 2.0
我有一個張量probs
,其形狀(None, None, 110)
(batch_size, sequence_length, 110)
。 我有另一個具有形狀(None, None)
的張量indices
,其中包含從probs
的第三維到 select 的元素的索引。
我想使用indices
來索引張量probs
。
Numpy 等效:
k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]
由於probs
的shape[0]
和shape[1]
未知, tf.meshgrid()
不是一個選項。 我找到tf.gather
、 tf.gather_nd
和tf.batch_gather
,但它們似乎都沒有做我想做的事。
有人知道怎么做這個嗎?
您可以像這樣使用tf.gather_nd
做到這一點:
indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)
順便說一句,在 NumPy 你可以使用np.take_along_axis
做同樣的事情:
indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.