簡體   English   中英

在 TensorFlow 中是否可以跨 n 維使用高級索引?

[英]Is advanced indexing available across n-dimensions in TensorFlow?

在 PyTorch 中,我們可以使用標准 Pythonic 索引來跨 n 維應用高級索引。

preds是一個形狀為[1, 3, 64, 64, 12]Tensor

a , b , c , d是相同長度的一維Tensor 在這種情況下,長度為 9,但並非總是如此。

PyTorch 示例實現了預期的結果:

result = preds[a, b, c, d]

result.shape
>>> [9, 12]

這如何在 TensorFlow 中重現,從相同的 5 個張量開始並創建相同的 output?

我已經嘗試過tf.gather似乎能夠在單個維度上產生相同的行為:

tf.shape(tf.gather(preds, a))
>>> [9, 3, 64, 64, 12]

是否可以擴展它以最終達到所需的 output 形狀[9, 12]

我還注意到tf.gather_nd的存在,這似乎在這里可能是相關的,但我無法確定如何從文檔中使用它。

是的, gather_nd可以做到這一點

t = tf.random.uniform(shape=(1,3,64,64,12))

# i_n = indices along n-th dim
i_1 = tf.constant([0,0,0,0,0,0,0,0,0])
i_2 = tf.constant([0,1,2,1,2,2,1,0,0])
i_3 = tf.constant([0,21,15,63,22,17,21,54,39])
i_4 = tf.constant([0,16,26,51,33,45,48,29,1])
i = tf.stack([i_1, i_2, i_3, i_4], axis=1)  # i.shape == (9,4)

tf.gather_nd(t, i).shape   # (9,12)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM