PyTorch gather question (3D Computer Vision)

I have N groups of C-dimensional points. Each group contains M points, so there is a tensor of shape (N, M, C). Let's call it features.

I computed the maximum element and its index along the M dimension to find the maximum point for each of the C dimensions (a max pooling operation). This yields a max tensor of shape (N, 1, C) and an index tensor of shape (N, 1, C).

I have another tensor of shape (N, M, 3) storing the geometric coordinates of those N*M high-dimensional points. Now I want to use the indices of the maximum points in each of the C dimensions to get the coordinates of all those maximum points.
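The pooling step described above can be sketched as follows. This is only an illustration under assumptions: the random features tensor and the names max_vals and indices are placeholders, not the asker's actual code.

```python
import torch

N, M, C = 2, 4, 6
torch.manual_seed(0)
features = torch.randn(N, M, C)  # placeholder data standing in for the real features

# Max over the M dimension; keepdim=True keeps the pooled axis as size 1,
# so both outputs have shape (N, 1, C).
max_vals, indices = features.max(dim=1, keepdim=True)

print(max_vals.shape, indices.shape)
```

Here indices[n, 0, c] is the position along M of the point in group n that has the largest value in channel c.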

For example, N=2, M=4, C=6.

The coordinate tensor, whose shape is (2, 4, 3):

[[[1, 2, 3]
  [4, 5, 6]
  [7, 8, 9]
  [8, 7, 6]]

 [[11, 12, 13]
  [14, 15, 16]
  [17, 18, 19]
  [18, 17, 16]]]

The indices tensor, whose shape is (2, 1, 6):

[[[0, 1, 2, 1, 2, 3]]
 [[1, 2, 3, 2, 1, 0]]]

For example, the first element in indices is 0, so I want to grab [1, 2, 3] from the first group of the coordinate tensor. For the second element (1), I want to grab [4, 5, 6]. For the third element of the second group (3), I want to grab [18, 17, 16] from the second group.

The result tensor will look like:

[[[1, 2, 3]  # 0
  [4, 5, 6]  # 1
  [7, 8, 9]  # 2
  [4, 5, 6]  # 1
  [7, 8, 9]  # 2
  [8, 7, 6]] # 3

 [[14, 15, 16] # 1
  [17, 18, 19] # 2
  [18, 17, 16] # 3
  [17, 18, 19] # 2
  [14, 15, 16] # 1
  [11, 12, 13]]]# 0

whose shape is (2, 6, 3).
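For reference, the mapping described above can be spelled out as an explicit loop. This is a slow sketch meant only to pin down the intended semantics; the function name gather_points_naive and the variable names are hypothetical.

```python
import torch

coords = torch.tensor([
    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [8, 7, 6]],
    [[11, 12, 13], [14, 15, 16], [17, 18, 19], [18, 17, 16]],
])                                                   # shape (2, 4, 3)
indices = torch.tensor([[[0, 1, 2, 1, 2, 3]],
                        [[1, 2, 3, 2, 1, 0]]])       # shape (2, 1, 6)

def gather_points_naive(coords, indices):
    """Hypothetical reference: out[n, j] = coords[n, indices[n, 0, j]]."""
    n, _, c = indices.shape
    out = coords.new_empty(n, c, coords.shape[-1])
    for b in range(n):
        for j in range(c):
            out[b, j] = coords[b, indices[b, 0, j]]
    return out

result = gather_points_naive(coords, indices)
print(result.shape)
```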

I tried to use torch.gather but could not get it to work. I wrote a naive algorithm that enumerates all N groups, but it is slow, even with TorchScript's JIT. So, how can I write this efficiently in PyTorch?

You can use integer array indexing combined with broadcasting semantics to get your result.

import torch

x = torch.tensor([
    [[1, 2, 3], 
     [4, 5, 6], 
     [7, 8, 9], 
     [8, 7, 6]],
    [[11, 12, 13],
     [14, 15, 16],
     [17, 18, 19],
     [18, 17, 16]],
])

i = torch.tensor([[[0, 1, 2, 1, 2, 3]],
                  [[1, 2, 3, 2, 1, 0]]])

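# rows has shape [2, 1] (one group index per group); cols has shape [2, 6]
rows = torch.arange(x.shape[0]).type_as(i).unsqueeze(1)
cols = i.squeeze(1)

# rows and cols broadcast to [2, 6], so y has shape [2, 6, 3]
# with y[n, j] = x[n, i[n, 0, j]]
y = x[rows, cols]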
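Since the question asked about torch.gather specifically, here is a sketch of how the same result might be obtained with it. It assumes the index tensor is expanded along the last dimension so that it matches the output shape; the names idx and y_gather are illustrative.

```python
import torch

x = torch.tensor([
    [[1, 2, 3], [4, 5, 6], [7, 8, 9], [8, 7, 6]],
    [[11, 12, 13], [14, 15, 16], [17, 18, 19], [18, 17, 16]],
])
i = torch.tensor([[[0, 1, 2, 1, 2, 3]],
                  [[1, 2, 3, 2, 1, 0]]])

# gather along dim=1 computes out[n, j, k] = x[n, idx[n, j, k], k],
# so the same point index must be repeated for every coordinate k.
idx = i.transpose(1, 2).expand(-1, -1, x.shape[-1])  # [2, 1, 6] -> [2, 6, 1] -> [2, 6, 3]
y_gather = x.gather(1, idx)

# Cross-check against integer array indexing with broadcasting.
rows = torch.arange(x.shape[0]).unsqueeze(1)
y_index = x[rows, i.squeeze(1)]
print(torch.equal(y_gather, y_index))
```

Both forms avoid a Python loop over the groups. The indexing form is shorter, while the gather form makes the per-coordinate index explicit.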

