[英]pytorch - reciprocal of torch.gather
Given an input tensor x
and a tensor of indices idxs
, I want to retrieve all elements of x
whose index is not present in idxs
.给定输入张量
x
和索引张量idxs
,我想检索x
索引不存在于idxs
所有元素。 That is, taking the opposite of the torch.gather
function output.也就是说,与
torch.gather
函数输出相反。
Example with torch.gather
:使用
torch.gather
示例:
>>> x = torch.arange(30).reshape(3,10)
>>> idxs = torch.tensor([[1,2,3], [4,5,6], [7,8,9]], dtype=torch.long)
>>> torch.gather(x, 1, idxs)
tensor([[ 1, 2, 3],
[14, 15, 16],
[27, 28, 29]])
What indeed I want to achieve is我真正想要实现的是
tensor([[ 0, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26]])
What could it be an effective and efficient implementation, possibly employing torch utilities?什么是有效和高效的实施,可能使用火炬实用程序? I wouldn't like to use any for-loops.
我不想使用任何 for 循环。
I'm assuming idxs
has only unique elements in its deepest dimension.我假设
idxs
在其最深的维度中只有独特的元素。 For example idxs
would be the result of calling torch.topk
.例如
idxs
将是调用torch.topk
的结果。
You could be looking to construct a tensor of shape (x.size(0), x.size(1)-idxs.size(1))
(here (3, 7)
).您可能希望构造一个形状为
(x.size(0), x.size(1)-idxs.size(1))
的张量(此处为(3, 7)
)。 Which would correspond to the complementary indices of idxs
, with regard to the shape of x
, ie :这将对应于
idxs
的互补索引,关于x
的形状,即:
tensor([[0, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6]])
I propose to first build a tensor shaped like x
that would reveal the positions we want to keep and those we want to discard, a sort of mask.我建议首先构建一个形状像
x
的张量,它会显示我们想要保留的位置和我们想要丢弃的位置,一种掩码。 This can be done using torch.scatter
.这可以使用
torch.scatter
来完成。 This essentially scatters 0
s at desired location, namely m[i, idxs[i][j]] = 0
:这基本上将
0
s 分散到所需位置,即m[i, idxs[i][j]] = 0
:
>>> m = torch.ones_like(x).scatter(1, idxs, 0)
tensor([[1, 0, 0, 0, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 0, 0, 0, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0]])
Then grab the non zeros (the complementary part of idxs
).然后获取非零(
idxs
的补充部分)。 Select the 2nd indices on axis=1
, and reshape according to the target tensor:选择
axis=1
上的第二个索引,并根据目标张量重塑:
>>> idxs_ = m.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
tensor([[0, 4, 5, 6, 7, 8, 9],
[0, 1, 2, 3, 7, 8, 9],
[0, 1, 2, 3, 4, 5, 6]])
Now you know what to do, right?现在你知道该怎么做了,对吧? Same as for the
torch.gather
example you gave, but this time with idxs_
:与您提供的
torch.gather
示例相同,但这次使用idxs_
:
>>> torch.gather(x, 1, idxs_)
tensor([[ 0, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26]])
In summary:总之:
>>> idxs_ = torch.ones_like(x).scatter(1, idxs, 0) \
.nonzero()[:, 1].reshape(-1, x.size(1) - idxs.size(1))
>>> torch.gather(x, 1, idxs_)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.