[英]How does PyTorch Tensor.index_select() evaluates tensor output?
I am not able to understand how complex indexing - non contiguous indexing of a tensor works.我无法理解索引的复杂性 - 张量的非连续索引是如何工作的。 Here is a sample code and its output
这是一个示例代码及其 output
import torch
def describe(x):
print("Type: {}".format(x.type()))
print("Shape/size: {}".format(x.shape))
print("Values: \n{}".format(x))
indices = torch.LongTensor([0,2])
x = torch.arange(6).view(2,3)
describe(torch.index_select(x, dim=1, index=indices))
Returns output as返回 output 为
Type: torch.LongTensor Shape/size: torch.Size([2, 2]) Values: tensor([[0, 2], [3, 5]])
类型:torch.LongTensor 形状/大小:torch.Size([2, 2]) 值:tensor([[0, 2], [3, 5]])
Can someone explain how did it arrive to this output tensor?有人能解释一下它是如何到达这个 output 张量的吗? Thanks!
谢谢!
You are selecting the first ( indices[0]
is 0
) and third ( indices[1]
is 2
) tensors from x
on the first axis ( dim=0
).您正在从第一个轴(
dim=0
)上的x
中选择第一个( indices[0]
为0
)和第三个( indices[1]
为2
)张量。 Essentially, torch.index_select
with dim=1
works the same as doing a direct indexing on the second axis with x[:, indices]
.本质上,使用
dim=1
的torch.index_select
与使用x[:, indices]
在第二个轴上进行直接索引相同。
>>> x
tensor([[0, 1, 2],
[3, 4, 5]])
So selecting columns (since you're looking at dim=1
and not dim=0
) which indices are in indices
.因此,选择哪些索引在
indices
中的列(因为您正在查看dim=1
而不是dim=0
)。 Imagine having a simple list [0, 2]
as indices
:想象有一个简单的列表
[0, 2]
作为indices
:
>>> indices = [0, 2]
>>> x[:, indices[0]] # same as x[:, 0]
tensor([0, 3])
>>> x[:, indices[1]] # same as x[:, 2]
tensor([2, 5])
So passing the indices as a torch.Tensor
allows you to index on all elements of indices directly, ie columns 0
and 2
.因此,将索引作为
torch.Tensor
允许您直接索引索引的所有元素,即列0
和2
。 Similar to how NumPy's indexing works.类似于 NumPy 的索引工作方式。
>>> x[:, indices]
tensor([[0, 2],
[3, 5]])
Here's another example to help you see how it works.这是另一个示例,可帮助您了解其工作原理。 With
x
defined as x = torch.arange(9).view(3, 3)
so we have 3 rows (aka dim=0
) and 3 columns (aka dim=1
).将
x
定义为x = torch.arange(9).view(3, 3)
所以我们有3行(又名dim=0
)和3列(又名dim=1
)。
>>> indices
tensor([0, 2]) # namely 'first' and 'third'
>>> x = torch.arange(9).view(3, 3)
tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
>>> x.index_select(0, indices) # select first and third rows
tensor([[0, 1, 2],
[6, 7, 8]])
>>> x.index_select(1, indices) # select first and third columns
tensor([[0, 2],
[3, 5],
[6, 8]])
Note : torch.index_select(x, dim, indices)
is equivalent to x.index_select(dim, indices)
注意:
torch.index_select(x, dim, indices)
等价于x.index_select(dim, indices)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.