簡體   English   中英

我如何通過索引從 pytorch 張量中獲取 select 值?

[英]How do I select values from a pytorch tensor by index?

我正在努力弄清楚如何使用 pytorch 解決 go 這個 class 問題。問題是“為所有 i,j 選擇值 x[i,j,k],其中 ind[i,j] = k 在張量中,
張量的形狀應為 (10,50)"

ind = torch.randint(50,(10,50))
x = torch.randn(10,50,50)

我可以使用torch.scatter.gather來做到這一點嗎?

你可以使用torch.gather ,你只需要擴展你的索引的暗淡:

y = torch.gather(x,2,ind[:,:,None]).squeeze(2)
assert y[0] == x[0,0,ind[0][0]]

這是因為索引必須與輸入張量具有相同的維度。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM