![](/img/trans.png)
[英]How to change certain values in a torch tensor based on an index in another torch tensor?
[英]extracting tensor values given tensor index values torch
我有一個形狀為(b,n)
的值張量val
和一個形狀為(b,m)
的索引張量ind
(其中n>m
)。 我的目標是獲取val
中與ind
中的索引相對應的值。 我試過使用val[ind]
,但它只擴展了val
的維度,而不是只取相關項目
val = torch.tensor([[1,2,3],
[4,5,6],
[7,8,9],
[10,11,12],
[13,14,15]])
ind = torch.tensor([[1,2],
[0,2],
[0,1],
[1,2],
[0,1]])
val[ind] # shaped (5,2,4), I need (5,2)
想要的輸出是
torch.tensor([[2,3],
[4,6],
[7,8],
[11,12],
[13,14]])
您可以使用torch.gather
執行此類操作:
>>> val.gather(dim=1, index=ind)
tensor([[ 2, 3],
[ 4, 6],
[ 7, 8],
[11, 12],
[13, 14]])
基本上使用ind
的值索引val
的第二維。 返回的張量out
如下:
out[i][j] = val[i][ind[i]]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.