Select specific indexes of 3D PyTorch Tensor using a 1D long tensor that represents indexes

So I have a tensor that is M x B x C, where M is the number of models, B is the batch size and C is the number of classes, and each cell is the probability of a class for a given model and batch element. Then I have a tensor of the correct answers, which is just a 1D tensor of size B that we'll call "t". How do I use the 1D tensor of size B to return an M x B x 1 tensor, where the returned tensor holds just the value at the correct class? Say the M x B x C tensor is called "blah". I've tried:

blah[:, :, t]

for i in range(M):
    blah[i, :, t]

blah[:, t, :]

The first two just return the values at the indexes in t along the 3rd dimension, for every row of every slice. The last one returns the values at the t indexes along the 2nd dimension. How do I do this?
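
To make the problem concrete, here is a small sketch of what those attempts produce. The sizes and the values of blah and t are made up for illustration; they are not from my real data.

```python
import torch

# Hypothetical small sizes, chosen only for illustration
M, B, C = 2, 3, 4
blah = torch.arange(M * B * C).reshape(M, B, C)
t = torch.tensor([2, 1, 0])  # one class index per batch element

# Indexing only the last dim with t applies ALL B indices to every batch row
attempt_last = blah[:, :, t]
print(attempt_last.shape)  # torch.Size([2, 3, 3])

# Indexing the batch dim with t just reorders the batch rows
attempt_mid = blah[:, t, :]
print(attempt_mid.shape)  # torch.Size([2, 3, 4])

# The values actually wanted sit on the diagonal of the two B-sized dims
wanted = attempt_last.diagonal(dim1=1, dim2=2)
print(wanted)  # tensor([[ 2,  5,  8], [14, 17, 20]])
```

So the first attempt does contain the right values, but buried in an M x B x B result rather than returned directly.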

We can get the desired result by combining advanced and basic indexing:

import torch

# shape [2, 3, 4]
blah = torch.tensor([
    [[ 0,  1,  2,  3],
     [ 4,  5,  6,  7],
     [ 8,  9, 10, 11]],
    [[12, 13, 14, 15],
     [16, 17, 18, 19],
     [20, 21, 22, 23]]])

# shape [3]
t = torch.tensor([2, 1, 0])
b = torch.arange(blah.shape[1]).type_as(t)

# shape [2, 3, 1]
result = blah[:, b, t].unsqueeze(-1)

which results in

>>> result
tensor([[[ 2],
         [ 5],
         [ 8]],
        [[14],
         [17],
         [20]]])
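
As a sanity check, the indexing expression can be compared against plain Python loops. This is only a sketch that reuses the example values from above; the variable named expected is introduced here for illustration.

```python
import torch

blah = torch.arange(24).reshape(2, 3, 4)  # same values as the example above
t = torch.tensor([2, 1, 0])
b = torch.arange(blah.shape[1])

# b and t are paired element-wise: (b[0], t[0]), (b[1], t[1]), ...
result = blah[:, b, t].unsqueeze(-1)

# Reference implementation with explicit loops
expected = torch.empty(2, 3, 1, dtype=blah.dtype)
for m in range(blah.shape[0]):
    for n in range(blah.shape[1]):
        expected[m, n, 0] = blah[m, n, t[n]]

print(torch.equal(result, expected))  # True
```

The key point is that the two index tensors are matched up position by position, while the leading `:` keeps every model.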

You simply need to pass:

  • your index tensor t as the third index
  • range(B) as the second index
    (i.e. which element in the 2nd dim each 3rd-dim index corresponds to)
blah[:,range(B),t]
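
A quick sketch of that one-liner on made-up data (the sizes and values here are assumptions for illustration). Note that it returns shape M x B, so a trailing unsqueeze is needed if you want M x B x 1.

```python
import torch

# Hypothetical sizes and values, for illustration only
M, B, C = 2, 3, 4
blah = torch.arange(M * B * C).reshape(M, B, C)
t = torch.tensor([2, 1, 0])

picked = blah[:, range(B), t]      # shape (M, B)
picked_3d = picked.unsqueeze(-1)   # shape (M, B, 1)

# Gives the same values as using a torch.arange index tensor
same = torch.equal(picked, blah[:, torch.arange(B), t])
print(picked.shape, picked_3d.shape, same)
```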

Here is one way to do it:

Suppose a is your M x B x C shaped tensor. I am taking some representative values below:

>>> M = 3
>>> B = 5
>>> C = 4
>>> a = torch.rand(M, B, C)
>>> a
tensor([[[0.6222, 0.6703, 0.0057, 0.3210],
         [0.6251, 0.3286, 0.8451, 0.5978],
         [0.0808, 0.8408, 0.3795, 0.4872],
         [0.8589, 0.8891, 0.8033, 0.8906],
         [0.5620, 0.5275, 0.4272, 0.2286]],

        [[0.2419, 0.0179, 0.2052, 0.6859],
         [0.1868, 0.7766, 0.3648, 0.9697],
         [0.6750, 0.4715, 0.9377, 0.3220],
         [0.0537, 0.1719, 0.0013, 0.0537],
         [0.2681, 0.7514, 0.6523, 0.7703]],

        [[0.5285, 0.5360, 0.7949, 0.6210],
         [0.3066, 0.1138, 0.6412, 0.4724],
         [0.3599, 0.9624, 0.0266, 0.1455],
         [0.7474, 0.2999, 0.7476, 0.2889],
         [0.1779, 0.3515, 0.8900, 0.2301]]])

Let's say the 1D class tensor is t, which gives the true class of each example in the batch. So it is a 1D tensor of shape (B,) with class labels in the range {0, 1, 2, ..., C-1}.

>>> t = torch.randint(C, size = (B, ))
>>> t
tensor([3, 2, 1, 1, 0])

So basically you want to select, from the innermost dimension of a, the entries at the indices given by t. This can be achieved by combining fancy indexing and broadcasting as follows:

>>> i = torch.arange(M).reshape(M, 1, 1)
>>> j = torch.arange(B).reshape(1, B, 1)
>>> k = t.reshape(1, B, 1)

Note that when you index anything by (i, j, k), the three index tensors are broadcast against each other to the shape (M, B, 1), which is the desired output shape. Now just indexing a by i, j and k gives:

>>> a[i, j, k]
tensor([[[0.3210],
         [0.8451],
         [0.8408],
         [0.8891],
         [0.5620]],

        [[0.6859],
         [0.3648],
         [0.4715],
         [0.1719],
         [0.2681]],

        [[0.6210],
         [0.6412],
         [0.9624],
         [0.2999],
         [0.1779]]])
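
If you want to convince yourself of the broadcast shape, a minimal sketch along these lines should work. It assumes a PyTorch version that provides torch.broadcast_shapes, and the fixed t below simply copies the sample values printed above.

```python
import torch

M, B, C = 3, 5, 4
a = torch.rand(M, B, C)
t = torch.tensor([3, 2, 1, 1, 0])  # same labels as the sample above

i = torch.arange(M).reshape(M, 1, 1)
j = torch.arange(B).reshape(1, B, 1)
k = t.reshape(1, B, 1)

# Shape the three index tensors broadcast to, i.e. the output shape
out_shape = torch.broadcast_shapes(i.shape, j.shape, k.shape)
print(out_shape)  # torch.Size([3, 5, 1])

picked = a[i, j, k]

# Every output entry is a[m, n, t[n]]
ok = all(picked[m, n, 0] == a[m, n, t[n]]
         for m in range(M) for n in range(B))
print(ok)  # True
```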

So essentially, if you generate the index arrays describing your access pattern beforehand, you can use them directly to extract the corresponding elements of the tensor.
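
For comparison, the same selection can also be written with torch.gather instead of fancy indexing. This is a different technique from the one shown above, and the helper name select_true_class is made up for this sketch.

```python
import torch

def select_true_class(probs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: pick probs[m, b, t[b]] for all m, b; returns (M, B, 1)."""
    M, B, _ = probs.shape
    # gather needs an index with the same number of dims as probs
    index = t.reshape(1, B, 1).expand(M, B, 1)
    return probs.gather(2, index)

M, B, C = 3, 5, 4
a = torch.rand(M, B, C)
t = torch.randint(C, size=(B,))

via_gather = select_true_class(a, t)
via_indexing = a[:, torch.arange(B), t].unsqueeze(-1)
print(torch.equal(via_gather, via_indexing))  # True
```

Both routes read the same elements, so which one to use is mostly a matter of taste.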
