簡體   English   中英

用 Tensorflow 2.0 中的另一個張量索引張量的第 k 維

[英]Indexing k-th dimension of tensor with another tensor in Tensorflow 2.0

我有一個張量probs ,其形狀(None, None, 110) (batch_size, sequence_length, 110) 我有另一個具有形狀(None, None)的張量indices ,其中包含從probs的第三維到 select 的元素的索引。

我想使用indices來索引張量probs

Numpy 等效:

k, j = np.meshgrid(np.arange(probs.shape[1]), np.arange(probs.shape[0]))
indexed_probs = probs[j, k, indices]

由於probsshape[0]shape[1]未知, tf.meshgrid()不是一個選項。 我找到tf.gathertf.gather_ndtf.batch_gather ,但它們似乎都沒有做我想做的事。

有人知道怎么做這個嗎?

您可以像這樣使用tf.gather_nd做到這一點:

indexed_probs = tf.gather_nd(probs, tf.expand_dims(indices, axis=-1), batch_dims=2)

順便說一句,在 NumPy 你可以使用np.take_along_axis做同樣的事情:

indexed_probs = np.take_along_axis(probs, np.expand_dims(indices, axis=-1), axis=-1)[..., 0]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM