![](/img/trans.png)
[英]I want to design a grid with n-number of dimensions ,solution for two dimension is available but i want n-dimensions in Numpy?
[英]Is advanced indexing available across n-dimensions in TensorFlow?
在 PyTorch 中,我們可以使用標准 Pythonic 索引來跨 n 維應用高級索引。
preds
是一個形狀為[1, 3, 64, 64, 12]
的Tensor
。
a
, b
, c
, d
是相同長度的一維Tensor
。 在這種情況下,長度為 9,但並非總是如此。
PyTorch 示例實現了預期的結果:
result = preds[a, b, c, d]
result.shape
>>> [9, 12]
這如何在 TensorFlow 中重現,從相同的 5 個張量開始並創建相同的 output?
我已經嘗試過tf.gather
似乎能夠在單個維度上產生相同的行為:
tf.shape(tf.gather(preds, a))
>>> [9, 3, 64, 64, 12]
是否可以擴展它以最終達到所需的 output 形狀[9, 12]
?
我還注意到tf.gather_nd
的存在,這似乎在這里可能是相關的,但我無法確定如何從文檔中使用它。
是的, gather_nd
可以做到這一點
t = tf.random.uniform(shape=(1,3,64,64,12))
# i_n = indices along n-th dim
i_1 = tf.constant([0,0,0,0,0,0,0,0,0])
i_2 = tf.constant([0,1,2,1,2,2,1,0,0])
i_3 = tf.constant([0,21,15,63,22,17,21,54,39])
i_4 = tf.constant([0,16,26,51,33,45,48,29,1])
i = tf.stack([i_1, i_2, i_3, i_4], axis=1) # i.shape == (9,4)
tf.gather_nd(t, i).shape # (9,12)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.