I have a PyTorch tensor b with the shape torch.Size([10, 10, 51]). I want to select among the 10 possible elements in dimension d=1 (the middle one) using a NumPy array, for example a = np.array([0,1,2,3,4,5,6,7,8,9]) (this is just a random example).
I wanted to do b[:,a,:], but that isn't working.
Your solution is likely torch.index_select (see its documentation). You'll have to turn a into a tensor first, though:
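import numpy as np
import torch

a_torch = torch.from_numpy(a)  # convert the NumPy index array to a tensor
answer = torch.index_select(b, 1, a_torch)  # select along dimension 1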
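A minimal sketch checking that torch.index_select along dimension 1 gives the same result as plain advanced indexing. The random tensor and the index values [3, 0, 7] are illustrative assumptions; the .long() cast is a defensive choice so the index tensor is int64 regardless of NumPy's platform default integer type.

```python
import numpy as np
import torch

# Illustrative data with the same shape as in the question
b = torch.rand(10, 10, 51)
a = np.array([3, 0, 7])  # hypothetical subset of indices along dim 1

# Convert the NumPy index array to an int64 tensor for index_select
a_torch = torch.from_numpy(a).long()

# Select slices 3, 0 and 7 along dimension 1
selected = torch.index_select(b, 1, a_torch)
print(selected.shape)  # torch.Size([10, 3, 51])

# Advanced indexing with the same index tensor gives the same values
assert torch.equal(selected, b[:, a_torch, :])
```

The output keeps the indices in the order given, so selected[:, 1] is b[:, 0].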
Indexing b on the second axis with a should do:
>>> b = torch.rand(10, 10, 51)
>>> a = np.array([0,1,2,3,4,5,6,7,8,9])
>>> b[:, a].shape
torch.Size([10, 10, 51])
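Because the example a covers all ten indices, b[:, a] happens to have the same shape as b. A sketch with a shorter, hypothetical index array [2, 5, 5] makes the effect visible: dimension 1 shrinks to len(a), and repeated indices are allowed.

```python
import numpy as np
import torch

b = torch.rand(10, 10, 51)

# Hypothetical index array: a subset of dim 1, with one index repeated
a = np.array([2, 5, 5])

out = b[:, a]
print(out.shape)  # torch.Size([10, 3, 51])

# Position k of the output holds slice a[k] of the input along dim 1
for k, j in enumerate(a):
    assert torch.equal(out[:, k], b[:, int(j)])
```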
I have found the solution on the PyTorch forum: ( https://discuss.pytorch.org/t/how-to-select-specific-vector-in-3d-tensor-beautifully/37724 )
import torch

x = torch.tensor([[[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]],
                  [[11, 12, 13],
                   [14, 15, 16],
                   [17, 18, 19]]])
idx = torch.tensor([1, 2])
# For each batch i along dim 0, take row idx[i] along dim 1
x[torch.arange(x.size(0)), idx]
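This forum approach picks a different index along dimension 1 for each entry of dimension 0, which is a different operation from index_select (one shared set of indices for every batch). A sketch applying it to a tensor shaped like the question's b, assuming a hypothetical random per-row index idx_b, and cross-checking it against torch.gather:

```python
import torch

# Forum example: batch 0 takes row 1, batch 1 takes row 2
x = torch.tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                  [[11, 12, 13], [14, 15, 16], [17, 18, 19]]])
idx = torch.tensor([1, 2])
picked = x[torch.arange(x.size(0)), idx]  # rows [4, 5, 6] and [17, 18, 19]

# Same idea on the question's shape: one index per row of dim 0
b = torch.rand(10, 10, 51)
idx_b = torch.randint(0, 10, (10,))  # hypothetical per-row indices
out = b[torch.arange(b.size(0)), idx_b]  # shape [10, 51]

# Equivalent with torch.gather: expand the index to [10, 1, 51],
# gather along dim 1, then drop the singleton dimension
index = idx_b.view(-1, 1, 1).expand(-1, 1, b.size(2))
gathered = torch.gather(b, 1, index).squeeze(1)
assert torch.equal(out, gathered)
```

The advanced-indexing form is shorter; gather is shown only to make explicit which element ends up where (out[i, k] is b[i, idx_b[i], k]).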