
Indexing a multi-dimensional tensor using only one dimension

I have a PyTorch tensor b with shape torch.Size([10, 10, 51]). For each index along dimension 0, I want to select one of the 10 elements in dimension d=1 (the middle one), using a NumPy array such as a = np.array([0,1,2,3,4,5,6,7,8,9]). This is just an arbitrary example.

I tried b[:,a,:], but that doesn't do what I want.
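To make the goal concrete, here is a minimal sketch (assuming a random b and the example a from the question) that spells out the intended per-row selection with a plain loop, next to the shape that b[:, a, :] actually produces. The name `expected` is introduced here for illustration only.

```python
import numpy as np
import torch

b = torch.rand(10, 10, 51)
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# b[:, a, :] reuses the whole list a for every row of dim 0,
# so dim 1 keeps all 10 selected entries.
sliced = b[:, a, :]
print(sliced.shape)  # torch.Size([10, 10, 51])

# The intended result: for each row i, keep only b[i, a[i], :].
expected = torch.stack([b[i, int(a[i]), :] for i in range(b.size(0))])
print(expected.shape)  # torch.Size([10, 51])
```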

Your solution is likely torch.index_select (see the docs).

You'll have to turn a into a tensor first, though.

a_torch = torch.from_numpy(a)               # convert the NumPy index array to a tensor
answer = torch.index_select(b, 1, a_torch)  # select the indices in a along dim 1
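A minimal sketch of what index_select returns for the question's shapes, assuming a random b. The .long() cast is a precaution added here (for platforms where NumPy's default integer is 32-bit), not part of the answer above. Note that index_select applies the same index list to every position in dim 0, so the result equals plain indexing b[:, a] and keeps shape (10, 10, 51).

```python
import numpy as np
import torch

b = torch.rand(10, 10, 51)
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

# Convert the NumPy indices to an int64 tensor; .long() guards against
# platforms where NumPy's default integer dtype is int32.
a_torch = torch.from_numpy(a).long()

# Select the entries listed in a_torch along dim 1, for every row of dim 0.
answer = torch.index_select(b, 1, a_torch)

print(answer.shape)                        # torch.Size([10, 10, 51])
print(torch.equal(answer, b[:, a_torch]))  # True: same as plain indexing
```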

Indexing b along the second axis with a should work:

>>> import numpy as np
>>> import torch
>>> b = torch.rand(10, 10, 51)
>>> a = np.array([0,1,2,3,4,5,6,7,8,9])

>>> b[:, a].shape
torch.Size([10, 10, 51])

I found the solution on the PyTorch forum (https://discuss.pytorch.org/t/how-to-select-specific-vector-in-3d-tensor-beautifully/37724):

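import torch

x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]],
                  [[11, 12, 13],
                   [14, 15, 16],
                   [17, 18, 19]]])

idx = torch.tensor([1, 2])
# Pair row i of dim 0 with idx[i] in dim 1: picks x[0, 1] and x[1, 2]
x[torch.arange(x.size(0)), idx]  # -> [[4, 5, 6], [17, 18, 19]]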
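Applied to the tensor from the question, the same pattern pairs each row index i from torch.arange with a[i], giving one 51-element vector per row. A minimal sketch, assuming a random b; the torch.gather variant is an alternative added here for comparison, not part of the forum answer.

```python
import numpy as np
import torch

b = torch.rand(10, 10, 51)
a = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
idx = torch.from_numpy(a).long()

# Advanced indexing: row i of dim 0 is paired with idx[i] in dim 1,
# so the result keeps one (51,)-vector per row.
picked = b[torch.arange(b.size(0)), idx]
print(picked.shape)  # torch.Size([10, 51])

# Alternative with torch.gather: expand idx to shape (10, 1, 51) so that
# out[i, 0, k] = b[i, idx[i], k], then drop the singleton dim.
gather_idx = idx.view(-1, 1, 1).expand(-1, 1, b.size(2))
picked_gather = torch.gather(b, 1, gather_idx).squeeze(1)
print(torch.equal(picked, picked_gather))  # True
```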
