
Advanced indexing in 2d tensor in pytorch

I have a 2D tensor X and two lists of indices, a and b, where a holds the first (row) index and b holds the second (column) index of each element. I want to do

for i in range(len(a)):
    X[a[i], b[i]] = 0

How can I do this? If I index directly with X[a, b], I get the error: IndexError: The advanced indexing objects could not be broadcast

Check the lists that contain the indices; some values might be out of range. In that case you will get an IndexError like the one below:

In [43]: X[4,4]

IndexError                                Traceback (most recent call last)
----> 1 X[4,4]

IndexError: index 4 is out of range for dimension 0 (of size 3)

If your indices are within the valid range, it should work fine.
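As a rough way to check the index lists before indexing, something like the sketch below may help. The helper name check_index_pairs is hypothetical (it is not part of PyTorch). It also checks that a and b have the same length, on the assumption that a length mismatch is another likely cause of a "could not be broadcast" error, since X[a, b] pairs the two lists element by element.

```python
import torch

def check_index_pairs(X, a, b):
    """Hypothetical helper: return a list of problems found in a and b (empty if none)."""
    problems = []
    # X[a, b] pairs a[i] with b[i], so the lists are expected to have equal length.
    if len(a) != len(b):
        problems.append(f"length mismatch: len(a)={len(a)}, len(b)={len(b)}")
    rows, cols = X.shape
    # Negative indices count from the end, so the valid range is [-size, size).
    for i, r in enumerate(a):
        if not -rows <= r < rows:
            problems.append(f"a[{i}]={r} is out of range for dimension 0 (size {rows})")
    for i, c in enumerate(b):
        if not -cols <= c < cols:
            problems.append(f"b[{i}]={c} is out of range for dimension 1 (size {cols})")
    return problems

X = torch.zeros(3, 4)
print(check_index_pairs(X, [0, 2], [1, 2]))     # → []
print(check_index_pairs(X, [0, 4], [1, 2]))     # reports the out-of-range row index
print(check_index_pairs(X, [0, 1, 2], [1, 2]))  # reports the length mismatch
```

An empty result suggests the lists are safe to use as X[a, b]; anything else points at the entry to fix.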

Here is an example:

In [35]: X = torch.Tensor([[3, 4, 5, 6], [1, 2, 3, 4], [6, 3, 2, 1]])

In [36]: X
Out[36]: 

 3  4  5  6
 1  2  3  4
 6  3  2  1
[torch.FloatTensor of size 3x4]

In [37]: a = [0, 2]

In [38]: b = [1, 2]

In [39]: X[a, b]
Out[39]: 

 4
 2
[torch.FloatTensor of size 2]

In [40]: X[a, b] = 0

In [41]: X
Out[41]: 

 3  0  5  6
 1  2  3  4
 6  3  0  1
[torch.FloatTensor of size 3x4]
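To confirm that the indexed assignment above does the same thing as the element-by-element loop from the question, a small comparison can be run. This is only a sketch: it uses torch.tensor and explicit long index tensors, which is an assumption about a more recent PyTorch version than the one in the session above, and the names X_loop, X_vec, rows and cols are introduced here for illustration.

```python
import torch

X = torch.tensor([[3., 4., 5., 6.],
                  [1., 2., 3., 4.],
                  [6., 3., 2., 1.]])
a = [0, 2]  # row indices
b = [1, 2]  # column indices

# Reference: the explicit loop from the question.
X_loop = X.clone()
for i in range(len(a)):
    X_loop[a[i], b[i]] = 0

# Advanced indexing: rows[i] is paired with cols[i].
X_vec = X.clone()
rows = torch.tensor(a, dtype=torch.long)
cols = torch.tensor(b, dtype=torch.long)
X_vec[rows, cols] = 0

print(torch.equal(X_loop, X_vec))  # → True
```

Both versions zero out the entries at (0, 1) and (2, 2) and leave the rest of X unchanged.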
