
How is a 3-d tensor indexed by two 2d tensors?

Here's a snapshot of lines 15-20 in DIM:

def random_permute(X):
    X = X.transpose(1, 2)
    b = torch.rand((X.size(0), X.size(1))).cuda()
    idx = b.sort(0)[1]
    adx = torch.range(0, X.size(1) - 1).long()  # note: torch.range is deprecated; torch.arange(X.size(1)) is equivalent
    X = X[idx, adx[None, :]].transpose(1, 2)

    return X

where X is a tensor of size [64, 64, 128], idx a tensor of size [64, 64], and adx a tensor of size [64]. How does X = X[idx, adx[None, :]] work? How can we use two 2D tensors to index a 3D tensor? What really happens to X after this indexing?

My guess is that X must be a 3D tensor, since it usually represents a batch of training data.

As far as the functionality of this function is concerned, it randomly permutes the input data tensor X, and it does this using the following steps:

  • First, it fills the tensor b with values sampled from a uniform distribution on [0, 1).
  • Next, this tensor is sorted along dimension 0, and the sorting indices are pulled out into the tensor idx.
  • The tensor adx is just an integer tensor with values ranging from 0 to X.size(1) - 1.
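The steps above can be sketched on a small tensor (the sizes here are made-up stand-ins for the real ones): the key observation is that each column of idx is an independent random permutation of the row indices.

```python
import torch

torch.manual_seed(0)
# Small stand-in for `b` of shape (X.size(0), X.size(1));
# the real code draws a much larger tensor after the transpose.
b = torch.rand(4, 3)
idx = b.sort(0)[1]          # for each column, the row order that sorts it
adx = torch.arange(3)       # 0, 1, 2 -- one entry per column

# Each column of `idx` is an independent permutation of the row indices:
for j in range(3):
    assert sorted(idx[:, j].tolist()) == [0, 1, 2, 3]
```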

Now, the below line is where all the magic happens:

X[idx, adx[None, :]].transpose(1, 2)

We use the indices idx and adx that we computed before (adx[None, :] is simply a two-dimensional row vector). Once we have that, we transpose axes 1 and 2, exactly like what we did at the beginning of the function in the line:

X = X.transpose(1, 2)
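Concretely, idx has shape (B, F) and adx[None, :] has shape (1, F); PyTorch broadcasts the two index tensors to a common (B, F) shape, so element (i, j) of the result takes row idx[i, j] at column j of X. A minimal sketch with made-up small sizes:

```python
import torch

torch.manual_seed(0)
B, F, D = 4, 3, 2               # small stand-ins for the real sizes
X = torch.rand(B, F, D)         # plays the role of X after the transpose
idx = torch.rand(B, F).sort(0)[1]
adx = torch.arange(F)

Y = X[idx, adx[None, :]]        # two broadcast index tensors of shape (B, F)
assert Y.shape == (B, F, D)

# Element-wise meaning: Y[i, j, :] == X[idx[i, j], j, :]
for i in range(B):
    for j in range(F):
        assert torch.equal(Y[i, j], X[idx[i, j], j])
```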

Here is a contrived example, for better understanding:

# our input tensor
In [51]: X = torch.rand(64, 64, 32)

In [52]: X = X.transpose(1, 2)

In [53]: X.shape
Out[53]: torch.Size([64, 32, 64])

In [54]: b = torch.rand((X.size(0), X.size(1)))

# sort `b` which returns a tuple and take only indices
In [55]: idx = b.sort(0)[1]

In [56]: idx.shape
Out[56]: torch.Size([64, 32])

In [57]: adx = torch.arange(0, X.size(1)).long()

In [58]: adx.shape
Out[58]: torch.Size([32])

In [59]: X[idx, adx[None, :]].transpose(1, 2).shape
Out[59]: torch.Size([64, 64, 32])

The important thing to note here is that the shape we get in the last step is the same as the shape of the input tensor, (64, 64, 32).
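To see that this really is a permutation, we can check on a CPU copy of the function (with torch.arange instead of the deprecated torch.range) that every slice along the batch dimension ends up with the same values, just reordered:

```python
import torch

torch.manual_seed(0)

def random_permute(X):
    # CPU version of the snippet from the question (no .cuda())
    X = X.transpose(1, 2)
    b = torch.rand((X.size(0), X.size(1)))
    idx = b.sort(0)[1]
    adx = torch.arange(X.size(1)).long()
    X = X[idx, adx[None, :]].transpose(1, 2)
    return X

X = torch.rand(8, 5, 7)
Y = random_permute(X)
assert Y.shape == X.shape

# For every position (j, k) along dims 1 and 2, the batch entries of Y
# are the same multiset of values as in X, merely reordered:
for j in range(5):
    for k in range(7):
        assert torch.equal(X[:, j, k].sort()[0], Y[:, j, k].sort()[0])
```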

Things become clearer if we consider a smaller concrete example. Let

x = np.arange(8).reshape(2, 2, 2)
b = np.random.rand(2, 2)
idx = b.argsort(0) # e.g. idx=[[1, 1], [0, 0]]
adx = np.arange(2)[None, :] # [[0, 1]]
y = x[idx, adx] # implicitly expanding 'adx' to [[0, 1], [0, 1]]

In this example, we'll have y as

y[0, 0] = x[idx[0, 0], adx[0, 0]] = x[1, 0]
y[0, 1] = x[idx[0, 1], adx[0, 1]] = x[1, 1]
y[1, 0] = x[idx[1, 0], adx[1, 0]] = x[0, 0]
...
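These hand-computed entries can be checked directly in NumPy, using a fixed idx in place of the random argsort so the result is deterministic:

```python
import numpy as np

x = np.arange(8).reshape(2, 2, 2)
idx = np.array([[1, 1], [0, 0]])   # the fixed example value of b.argsort(0)
adx = np.arange(2)[None, :]        # [[0, 1]]
y = x[idx, adx]                    # adx broadcasts to [[0, 1], [0, 1]]

assert (y[0, 0] == x[1, 0]).all()
assert (y[0, 1] == x[1, 1]).all()
assert (y[1, 0] == x[0, 0]).all()
assert (y[1, 1] == x[0, 1]).all()
```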

It may be helpful to see how we do the same in TensorFlow:

d0, d1, d2 = x.shape.as_list()
b = np.random.rand(d0, d1)
idx = np.argsort(b, 0)
idx = idx.reshape(-1)
adx = np.arange(0, d1)
adx = np.tile(adx, d0)

# stack the two flat index vectors into (d0 * d1, 2) index pairs;
# note that a bare zip(idx, adx) would be a Python 3 iterator,
# which tf.gather_nd cannot consume
indices = np.stack([idx, adx], axis=1)
y = tf.reshape(tf.gather_nd(x, indices), (d0, d1, d2))
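Back in PyTorch, the same selection can also be written with torch.gather along dimension 0, which avoids building adx entirely. This is a sketch of an alternative, not the original code from DIM:

```python
import torch

torch.manual_seed(0)
B, F, D = 4, 3, 2
X = torch.rand(B, F, D)
idx = torch.rand(B, F).sort(0)[1]
adx = torch.arange(F)

y1 = X[idx, adx[None, :]]
# torch.gather along dim 0: out[i, j, k] = X[index[i, j, k], j, k],
# so expanding idx over the last dimension reproduces the indexing above
y2 = X.gather(0, idx[:, :, None].expand(B, F, D))
assert torch.equal(y1, y2)
```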
