[英]How to set value at Tensor index in batch
我有一個批量大小為 N 的張量。
t = [[...], [....], [....] .... ]
在第二個張量索引中,我有 N 個要在每個張量中更改的元素索引
indices = [i0, i1, i2 .... ]
所以我想通過以下方式從t
創建t0
:
t0 = [[ set X at i0 ], [ set X at i1 ], [ set X at i2 ] .... ]
我如何在 Torch 做到這一點?
您似乎正在尋找以下內容:
t[torch.arange(N),indices]
舉個例子:
import torch
a = torch.zeros((3,3))
a[torch.arange(3),[0,2,1]] = 0.2
print(a)
Output:
tensor([[0.2000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.2000],
[0.0000, 0.2000, 0.0000]])
注意:此行為與 NumPy 的integer 數組索引相同
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.