
PyTorch: how to do gathers over multiple dimensions

I'm trying to find a way to do this without for loops.

Say I have a multi-dimensional tensor t0:

bs = 4
seq = 10
v = 16
t0 = torch.rand((bs, seq, v))

This has shape: torch.Size([4, 10, 16])

I have another tensor labels that is a batch of 5 random indices in the seq dimension:

sample = 5
labels = torch.randint(0, seq, size=[bs, sample])

So this has shape torch.Size([4, 5]). This is used to index the seq dimension of t0.

What I want to do is loop over the batch dimension, doing gathers using the labels tensor. My brute-force solution is this:

t1 = torch.empty((bs, sample, v))
for b in range(bs):
    for idx0, idx1 in enumerate(labels[b]):
        t1[b, idx0, :] = t0[b, idx1, :]

This results in a tensor t1 of shape torch.Size([4, 5, 16]).

Is there a more idiomatic way of doing this in PyTorch?

You can use fancy indexing here to select the desired portion of the tensor.

Essentially, if you generate the index arrays conveying your access pattern beforehand, you can use them directly to extract a slice of the tensor. The index arrays should broadcast together to the shape of the output tensor or slice you want to extract.

i = torch.arange(bs).reshape(bs, 1, 1)  # shape = [bs, 1, 1]
j = labels.reshape(bs, sample, 1)       # shape = [bs, sample, 1]
k = torch.arange(v)                     # shape = [v]

# Get result as
t1 = t0[i, j, k]

Note the shapes of the above 3 tensors. Broadcasting prepends extra dimensions of size 1 to the front of a tensor, essentially reshaping k to shape [1, 1, v], which makes all 3 of them compatible for elementwise operations.

After broadcasting, (i, j, k) together produce three [bs, sample, v]-shaped index arrays, and those index your original tensor elementwise to produce the output tensor t1 of shape [bs, sample, v].
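
As a quick sanity check (a minimal, self-contained sketch using the names from the question), the fancy-indexing result matches the loop-based version exactly:

import torch

bs, seq, v, sample = 4, 10, 16, 5
t0 = torch.rand((bs, seq, v))
labels = torch.randint(0, seq, size=[bs, sample])

# Index arrays as described above
i = torch.arange(bs).reshape(bs, 1, 1)  # shape = [bs, 1, 1]
j = labels.reshape(bs, sample, 1)       # shape = [bs, sample, 1]
k = torch.arange(v)                     # shape = [v]
t1 = t0[i, j, k]

# Brute-force reference from the question
t1_loop = torch.empty((bs, sample, v))
for b in range(bs):
    for idx0, idx1 in enumerate(labels[b]):
        t1_loop[b, idx0, :] = t0[b, idx1, :]

assert torch.equal(t1, t1_loop)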

You could do it like this:

t1 = t0[[[b] for b in range(bs)], labels]

or

t1 = torch.stack([t0[b, labels[b]] for b in range(bs)])
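
For completeness, torch.gather can also express this pattern without any Python-level loop; this is a sketch assuming the bs, sample, and v names from the question. gather requires the index tensor to have the same number of dimensions as the input, so labels is expanded from [bs, sample] to [bs, sample, v] first:

# index[b, s, k] = labels[b, s] for every k
index = labels.unsqueeze(-1).expand(bs, sample, v)
t1 = t0.gather(1, index)  # t1[b, s, k] = t0[b, labels[b, s], k]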
