簡體   English   中英

具有可變長度數組的索引多維火炬張量

[英]Index multidimensional torch tensor with array of variable length

我有一個索引列表和一個形狀的張量:

shape = [batch_size, d_0, d_1, ..., d_k]
idx = [i_0, i_1, ..., i_k]

有沒有辦法用索引i_0, ..., i_k有效地索引每個暗淡d_0, ..., d_k上的張量? k僅在運行時可用)

結果應該是:

tensor[:, i_0, i_1, ..., i_k] #tensor.shape = [batch_size]

目前我正在創建一個切片元組,每個維度一個:

idx = (slice(tensor.shape[0]),) + tuple(slice(i, i+1) for i in idx)
tensor[idx]

但我更喜歡類似的東西:

tensor[:, *idx]

例子:

a = torch.randint(0,10,[3,3,3,3])
indexes = torch.LongTensor([1,1,1])

我只想索引最后的 len(indexes) 維度,例如:

a[:, indexes[0], indexes[1], indexes[2]]

但在一般情況下,我不知道indexes有多長。


注意:這個答案沒有幫助,因為它索引了所有維度,並且不適用於適當的子集!

不幸的是,您不能為索引提供1個切片和迭代器的組合(例如a[:,*idx] )。 但是,您可以通過將其包裹在括號中以強制轉換為迭代器來實現幾乎相同的效果:

a[(slice(None), *idx)]

  1. 在 Python 中, x[(exp1, exp2, ..., expN)]等價於x[exp1, exp2, ..., expN] 后者只是前者的語法糖。

    https://numpy.org/doc/stable/reference/arrays.indexing.html

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM