
How to set value at Tensor index in batch

I have a tensor of batch size N.

t = [[...], [....], [....] .... ]

In a second tensor, indices, I have N indices: for each row of t, the position of the element I want to change.

indices = [i0, i1, i2 .... ]

So I want to create a new tensor t0 from t like this:

t0 = [[ set X at i0 ], [ set X at i1 ], [ set X at i2 ] .... ]

How can I do this in PyTorch?
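For reference, the intended result can be spelled out as an explicit Python loop. This is a minimal sketch that assumes t is a 2-D tensor and indices holds one column index per row. The name set_at_indices_loop is hypothetical, and the loop is slow for large N. It is only useful for pinning down the semantics or checking a vectorized version against it.

```python
import torch

def set_at_indices_loop(t, indices, value):
    # Hypothetical reference helper: copy t, then set one element per row in a Python loop
    t0 = t.clone()  # leave the original t unchanged
    for row, col in enumerate(indices.tolist()):
        t0[row, col] = value
    return t0

t = torch.zeros((3, 3))
indices = torch.tensor([0, 2, 1])
t0 = set_at_indices_loop(t, indices, 0.2)
print(t0)
```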

It seems like you're looking for the following:

t[torch.arange(N), indices] = X
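The assignment above modifies t in place. If you want a separate t0 and need to keep t intact, as the question describes, one option is to clone first. Below is a minimal sketch, assuming a 2-D t and a 1-D integer tensor indices of length N. The helper name set_at_indices is hypothetical.

```python
import torch

def set_at_indices(t, indices, value):
    # Hypothetical helper: returns a copy of t where row i has `value` at column indices[i]
    t0 = t.clone()  # keep the original t unchanged
    rows = torch.arange(t0.shape[0], device=t0.device)  # 0, 1, ..., N-1
    t0[rows, indices] = value  # pairs (rows[i], indices[i]) select one element per row
    return t0

t = torch.arange(12.0).reshape(3, 4)
indices = torch.tensor([3, 0, 2])
t0 = set_at_indices(t, indices, -1.0)
print(t0)
```

Here rows and indices are paired element-wise, so exactly one entry per row is overwritten.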

As an example:

import torch

a = torch.zeros((3, 3))
# row k gets the value 0.2 at column [0, 2, 1][k]
a[torch.arange(3), [0, 2, 1]] = 0.2
print(a)

Output:

tensor([[0.2000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.2000],
        [0.0000, 0.2000, 0.0000]])
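An alternative that is not part of the original answer is Tensor.scatter, which is out-of-place and returns a new tensor. The sketch below reproduces the same result. It assumes indices is an int64 tensor, which scatter requires, and it reshapes indices to (N, 1) because the index must have the same number of dimensions as the input.

```python
import torch

a = torch.zeros((3, 3))
indices = torch.tensor([0, 2, 1])  # int64 by default, as scatter requires

# scatter is out-of-place: it returns a new tensor and leaves `a` unchanged.
# Along dim 1, row k of the result gets 0.2 at column indices[k].
b = a.scatter(1, indices.unsqueeze(1), 0.2)
print(b)
```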

Note: this works the same way as NumPy's integer array (advanced) indexing.
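For comparison, here is the same example written with NumPy. It is a small sketch showing that the integer-array indexing pattern carries over directly.

```python
import numpy as np

a = np.zeros((3, 3))
# same pattern: row k gets 0.2 at column [0, 2, 1][k]
a[np.arange(3), [0, 2, 1]] = 0.2
print(a)
```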
