简体   繁体   English

pytorch - torch.gather 的倒数

[英]pytorch - reciprocal of torch.gather

Given an input tensor x and a tensor of indices idxs , I want to retrieve all elements of x whose index is not present in idxs .给定输入张量x和索引张量idxs ,我想检索x索引不存在于idxs所有元素。 That is, taking the opposite of the torch.gather function output.也就是说,与torch.gather函数输出相反。

Example with torch.gather :使用torch.gather示例:

>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1,  2,  3],
        [14, 15, 16],
        [27, 28, 29]])

What indeed I want to achieve is我真正想要实现的是

tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])

What could it be an effective and efficient implementation, possibly employing torch utilities?什么是有效和高效的实施,可能使用火炬实用程序? I wouldn't like to use any for-loops.我不想使用任何 for 循环。

I'm assuming idxs has only unique elements in its deepest dimension.我假设idxs在其最深的维度中只有独特的元素。 For example idxs would be the result of calling torch.topk .例如idxs将是调用torch.topk的结果。

You could be looking to construct a tensor of shape (x.size(0), x.size(1)-idxs.size(1)) (here (3, 7) ).您可能希望构造一个形状为(x.size(0), x.size(1)-idxs.size(1))的张量(此处为(3, 7) )。 Which would correspond to the complementary indices of idxs , with regard to the shape of x , ie :这将对应于idxs的互补索引,关于x的形状,

tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])

I propose to first build a tensor shaped like x that would reveal the positions we want to keep and those we want to discard, a sort of mask.我建议首先构建一个形状像x的张量,它会显示我们想要保留的位置和我们想要丢弃的位置,一种掩码。 This can be done using torch.scatter .这可以使用torch.scatter来完成。 This essentially scatters 0 s at desired location, namely m[i, idxs[i][j]] = 0 :这基本上将0 s 分散到所需位置,即m[i, idxs[i][j]] = 0

>>> m = torch.ones_like(x).scatter(1, idxs, 0)
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])

Then grab the non zeros (the complementary part of idxs ).然后获取非零( idxs的补充部分)。 Select the 2nd indices on axis=1 , and reshape according to the target tensor:选择axis=1上的第二个索引,并根据目标张量重塑:

>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
tensor([[0, 4, 5, 6, 7, 8, 9],
        [0, 1, 2, 3, 7, 8, 9],
        [0, 1, 2, 3, 4, 5, 6]])

Now you know what to do, right?现在你知道该怎么做了,对吧? Same as for the torch.gather example you gave, but this time with idxs_ :与您提供的torch.gather示例相同,但这次使用idxs_

>>> torch.gather(x, 1, idxs_)
tensor([[ 0,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26]])

In summary:总之:

>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
        .nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))

>>> torch.gather(x, 1, idxs_)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM