简体   繁体   English

PyTorch Tensor.index_select() 如何评估张量 output?

[英]How does PyTorch Tensor.index_select() evaluates tensor output?

I am not able to understand how complex indexing - non contiguous indexing of a tensor works.我无法理解索引的复杂性 - 张量的非连续索引是如何工作的。 Here is a sample code and its output这是一个示例代码及其 output

import torch

def describe(x):
  print("Type: {}".format(x.type()))
  print("Shape/size: {}".format(x.shape))
  print("Values: \n{}".format(x))


indices = torch.LongTensor([0,2])
x = torch.arange(6).view(2,3)
describe(torch.index_select(x, dim=1, index=indices))

Returns output as返回 output 为

Type: torch.LongTensor Shape/size: torch.Size([2, 2]) Values: tensor([[0, 2], [3, 5]])类型:torch.LongTensor 形状/大小:torch.Size([2, 2]) 值:tensor([[0, 2], [3, 5]])

Can someone explain how did it arrive to this output tensor?有人能解释一下它是如何到达这个 output 张量的吗? Thanks!谢谢!

You are selecting the first ( indices[0] is 0 ) and third ( indices[1] is 2 ) tensors from x on the first axis ( dim=0 ).您正在从第一个轴( dim=0 )上的x中选择第一个( indices[0]0 )和第三个( indices[1]2 )张量。 Essentially, torch.index_select with dim=1 works the same as doing a direct indexing on the second axis with x[:, indices] .本质上,使用dim=1torch.index_select与使用x[:, indices]在第二个轴上进行直接索引相同。

>>> x
tensor([[0, 1, 2],
        [3, 4, 5]])

So selecting columns (since you're looking at dim=1 and not dim=0 ) which indices are in indices .因此,选择哪些索引在indices中的列(因为您正在查看dim=1而不是dim=0 )。 Imagine having a simple list [0, 2] as indices :想象有一个简单的列表[0, 2]作为indices

>>> indices = [0, 2]

>>> x[:, indices[0]] # same as x[:, 0]
tensor([0, 3])

>>> x[:, indices[1]] # same as x[:, 2]
tensor([2, 5])

So passing the indices as a torch.Tensor allows you to index on all elements of indices directly, ie columns 0 and 2 .因此,将索引作为torch.Tensor允许您直接索引索引的所有元素,即列02 Similar to how NumPy's indexing works.类似于 NumPy 的索引工作方式。

>>> x[:, indices]
tensor([[0, 2],
        [3, 5]])

Here's another example to help you see how it works.这是另一个示例,可帮助您了解其工作原理。 With x defined as x = torch.arange(9).view(3, 3) so we have 3 rows (aka dim=0 ) and 3 columns (aka dim=1 ).x定义为x = torch.arange(9).view(3, 3)所以我们有3行(又名dim=0 )和3列(又名dim=1 )。

>>> indices
tensor([0, 2]) # namely 'first' and 'third'

>>> x = torch.arange(9).view(3, 3)
tensor([[0, 1, 2],
        [3, 4, 5],
        [6, 7, 8]])

>>> x.index_select(0, indices) # select first and third rows
tensor([[0, 1, 2],
        [6, 7, 8]])

>>> x.index_select(1, indices) # select first and third columns
tensor([[0, 2],
        [3, 5],
        [6, 8]])

Note : torch.index_select(x, dim, indices) is equivalent to x.index_select(dim, indices)注意torch.index_select(x, dim, indices)等价于x.index_select(dim, indices)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM