
PyTorch: index 2D tensor with 2D tensor of row indices

I have a torch tensor a of shape (x, n) and another tensor b of shape (y, n), where y <= x. Every column of b contains a sequence of row indices into a. I would like to index a with b so that I obtain a tensor of shape (y, n) whose ith column contains a[:, i][b[:, i]], i.e. element (j, i) of the result is a[b[j, i], i] (not quite sure if that's the correct way to express it).
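To make the per-column rule concrete, here is a slow but explicit reference loop. It is a sketch rather than a proposed solution, and the helper name take_per_column is hypothetical. It uses the same a and b as the example that follows:

```python
import torch

def take_per_column(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: column i of the result is a[:, i][b[:, i]]."""
    y, n = b.shape
    c = torch.empty(y, n, dtype=a.dtype)
    for i in range(n):
        # Pick the rows listed in b[:, i] out of column i of a.
        c[:, i] = a[:, i][b[:, i]]
    return c

a = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                  [0.6, 0.7, 0.8, 0.9],
                  [1.1, 1.2, 1.3, 1.4],
                  [1.6, 1.7, 1.8, 1.9],
                  [2.1, 2.2, 2.3, 2.4]])
b = torch.tensor([[0, 3, 1, 2],
                  [2, 2, 2, 0],
                  [1, 1, 0, 4]])  # int64 by default, as indexing requires

print(take_per_column(a, b))
```

The loop only pins down the semantics. It returns a copy, so it does not help with assignment.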

Here's an example (where x = 5, y = 3 and n = 4):

import torch

a = torch.Tensor(
    [[0.1, 0.2, 0.3, 0.4],
     [0.6, 0.7, 0.8, 0.9],
     [1.1, 1.2, 1.3, 1.4],
     [1.6, 1.7, 1.8, 1.9],
     [2.1, 2.2, 2.3, 2.4]]
)

b = torch.LongTensor(
    [[0, 3, 1, 2],
     [2, 2, 2, 0],
     [1, 1, 0, 4]]
)

# How do I get from a and b to c
# (so that I can also assign to those elements in a)?

c = torch.Tensor(
    [[0.1, 1.7, 0.8, 1.4],
     [1.1, 1.2, 1.3, 0.4],
     [0.6, 0.7, 0.3, 2.4]]
)

I can't get my head around this. What I'm looking for is a method that will not only yield the tensor c but also let me assign a tensor of the same shape as c to the elements of a that c is made up of.
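One way to get both reading and writing is PyTorch advanced indexing, which is not what the answer below uses. Pair b with a broadcast column index so that a[b, cols] selects a[b[j, i], i]. The same expression on the left-hand side of an assignment writes into a in place. A minimal sketch, assuming the example tensors above (the names cols and new_values are illustrative):

```python
import torch

a = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                  [0.6, 0.7, 0.8, 0.9],
                  [1.1, 1.2, 1.3, 1.4],
                  [1.6, 1.7, 1.8, 1.9],
                  [2.1, 2.2, 2.3, 2.4]])
b = torch.tensor([[0, 3, 1, 2],
                  [2, 2, 2, 0],
                  [1, 1, 0, 4]])

# Column index of shape (1, n); it broadcasts against b's shape (y, n).
cols = torch.arange(b.shape[1]).unsqueeze(0)

# Reading: c[j, i] = a[b[j, i], i]. Advanced indexing returns a copy.
c = a[b, cols]
print(c)

# Writing: the same index pair on the left-hand side assigns into a in place.
new_values = torch.zeros_like(c)
a[b, cols] = new_values
print(a)
```

If a column of b repeats a row index, which of the written values ends up in a is not guaranteed. In the example every column of b has distinct indices, so the write is well defined.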

I tried to use index_select, but it only accepts a 1-D index, so here it is applied one column at a time (after transposing):

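import torch
# index_select only takes a 1-D index, so transpose a and b:
# row i of at / bt is then column i of a / b.
bt = b.transpose(0, 1)
at = a.transpose(0, 1)
ct = [torch.index_select(at[i], dim=0, index=bt[i]) for i in range(len(at))]
# Stack the per-column results and transpose back to shape (y, n).
c = torch.stack(ct).transpose(0, 1)
print(c)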
"""
tensor([[0.1000, 1.7000, 0.8000, 1.4000],
        [1.1000, 1.2000, 1.3000, 0.4000],
        [0.6000, 0.7000, 0.3000, 2.4000]])
"""

It might not be the best solution, and because index_select returns a new tensor, it does not cover the assignment part. I hope it helps at least.
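A vectorized alternative to the per-column index_select loop is torch.gather with dim=0, which accepts a 2-D index tensor directly. Its in-place counterpart, Tensor.scatter_, covers the assignment part. This is a sketch that is not from the original answer, assuming the example a and b from the question (src is an illustrative name):

```python
import torch

a = torch.tensor([[0.1, 0.2, 0.3, 0.4],
                  [0.6, 0.7, 0.8, 0.9],
                  [1.1, 1.2, 1.3, 1.4],
                  [1.6, 1.7, 1.8, 1.9],
                  [2.1, 2.2, 2.3, 2.4]])
b = torch.tensor([[0, 3, 1, 2],
                  [2, 2, 2, 0],
                  [1, 1, 0, 4]])  # gather/scatter_ need an int64 index

# gather along dim 0: c[j][i] = a[b[j][i]][i], i.e. a[:, i][b[:, i]] per column.
c = torch.gather(a, 0, b)
print(c)

# scatter_ along dim 0 writes back in place: a[b[j][i]][i] = src[j][i].
src = torch.ones_like(c)
a.scatter_(0, b, src)
print(a)
```

As with index assignment, if a column of b contains the same row index twice, which value scatter_ leaves there is not guaranteed.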
