[英]How do I select values from a pytorch tensor by index?
我正在努力弄清楚如何使用 pytorch 解決 go 這個 class 問題。問題是“為所有 i,j 選擇值 x[i,j,k],其中 ind[i,j] = k 在張量中,
張量的形狀應為 (10,50)"
ind = torch.randint(50,(10,50))
x = torch.randn(10,50,50)
我可以使用torch.scatter
或.gather
來做到這一點嗎?
你可以使用torch.gather
,你只需要擴展你的索引的暗淡:
y = torch.gather(x,2,ind[:,:,None]).squeeze(2)
assert y[0] == x[0,0,ind[0][0]]
這是因為索引必須與輸入張量具有相同的維度。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.