繁体   English   中英

用另一个多维张量索引多维火炬张量

[英]Index multidimensional torch tensor by another multidimensional tensor

我在 pytorch 中有一个张量x假设形状 (5,3,2,6) 和另一个形状 (5,3,2,1) 的张量idx ,其中包含第一个张量中每个元素的索引。 我想用第二个张量的索引对第一个张量进行切片。 我尝试了 x= x[idx] 但是当我真的希望它的形状为 (5,3,2) 或 (5,3,2,1) 时,我得到了一个奇怪的维度。

我将尝试举一个更简单的例子:假设

x=torch.Tensor([[10,20,30],
                 [8,4,43]])
idx = torch.Tensor([[0],
                    [2]])

我想要类似的东西

y = x[idx]

这样 'y' 输出[[10],[43]]或类似的东西。

索引代表最后一维所需元素的 position。 对于上面的示例,其中 x.shape = (2,3) 最后一个维度是列,那么 'idx' 中的索引是列。 我想要这个,但超过 2 个维度

根据我从评论中了解到的情况,您需要idx作为最后一个维度中的索引,并且idx中的每个索引对应于x中的相似索引(最后一个维度除外)。 在这种情况下(这是 numpy 版本,您可以将其转换为手电筒):

ind = np.indices(idx.shape)
ind[-1] = idx
x[tuple(ind)]

output:

[[10]
 [43]]

您可以使用range squeeze以获得适当的idx尺寸,例如

x[range(x.size(0)), idx.squeeze()]
tensor([10., 43.])

# or
x[range(x.size(0)), idx.squeeze()].unsqueeze(1)
tensor([[10.],
        [43.]])

这是使用 collect 在gather中工作的那个。 idx需要采用以下行将确保的torch.int64格式(注意tensor中 't' 的小写字母)。

idx = torch.tensor([[0],
                    [2]])
torch.gather(x, 1, idx) # 1 is the axis to index here
tensor([[10.],
        [43.]])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM