簡體   English   中英

在給定張量索引值的情況下提取張量值

[英]extracting tensor values given tensor index values torch

我有一個形狀為(b,n)的值張量val和一個形狀為(b,m)的索引張量ind (其中n>m )。 我的目標是獲取val中與ind中的索引相對應的值。 我試過使用val[ind] ,但它只擴展了val的維度,而不是只取相關項目

val = torch.tensor([[1,2,3],
                    [4,5,6],
                    [7,8,9],
                    [10,11,12],
                    [13,14,15]])   
ind = torch.tensor([[1,2],
                    [0,2],
                    [0,1],
                    [1,2],
                    [0,1]])
val[ind] # shaped (5,2,4), I need (5,2)

想要的輸出是

torch.tensor([[2,3],
              [4,6],
              [7,8],
              [11,12],
              [13,14]])

您可以使用torch.gather執行此類操作:

>>> val.gather(dim=1, index=ind)
tensor([[ 2,  3],
        [ 4,  6],
        [ 7,  8],
        [11, 12],
        [13, 14]])

基本上使用ind的值索引val的第二維。 返回的張量out如下:

out[i][j] = val[i][ind[i]]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM