[英]Applying torch.combinations on multidimensional tensor or tuple of tensors in PyTorch?
使用 PyTorch, torch.combinations
只会将一维张量作为输入,但我想将其应用于多维张量中的每个一维张量。
inp = torch.tensor([[1, 2, 3],
[2, 3, 4]])
torch.combinations((inp), r=2)
结果是一个错误,说我无法将其应用于该形状,但我想将其分别应用于[1, 2, 3]
和[2, 3, 4]
。 我不能一一做,因为我的想法是将其应用于大量数据。
inp = torch.tensor([[1,2,3],[2,3,4]])
inp_tuple = torch.unbind(inp)
print(inp_tuple)
(tensor([1, 2, 3]), tensor([2, 3, 4]))
torch.combinations((inp_tuple), r=2)
我还尝试取消绑定张量并将其应用于张量的元组,但它给出了一个错误,说它不能应用于元组。
有什么方法可以让torch.combinations
自动应用于多维张量中的每个单独的一维张量或张量元组中的每个张量? 如果没有,是否有任何替代方法可以实现多维张量的每个单独部分的所有组合?
Function torch.combinations
返回一维输入向量中包含的元素的大小r
的所有可能组合。 不支持多维输入的原因可能是您无法保证输入中的不同向量具有完全相同数量的唯一元素。 显然,如果其中一个向量具有重复元素,那么您最终会得到一组比另一组更大的组合,这根本不可能用同质 PyTorch 张量来表示。
因此,从那里开始,我将假设输入张量inp
是一个 2D 张量形状(N, C)
,其中N
个向量中的每一个都包含C
唯一元素。 您提供的示例符合此要求,因为两个向量各有三个唯一元素: {1, 2, 3}
和{2, 3, 4}
。
>>> inp = torch.tensor([[1,2,3],[2,3,4]])
这个想法是将torch.combinations
应用于长度等于我们向量的排列张量。 然后我们可以使用它们作为索引来收集输入张量中不同向量中的值。
我们可以使用以下内容检索排列的所有组合:
>>> c = torch.combinations(torch.arange(inp.size(1)), r=2)
tensor([[0, 1],
[0, 2],
[1, 2]])
然后我们需要重塑和扩展c
inp
它们在维度数上匹配:
>>> x = inp[:,None].expand(-1,len(c),-1)
tensor([[[1, 2, 3],
[1, 2, 3],
[1, 2, 3]],
[[2, 3, 4],
[2, 3, 4],
[2, 3, 4]]])
>>> idx = c[None].expand(len(x), -1, -1)
tensor([[[0, 1],
[0, 2],
[1, 2]],
[[0, 1],
[0, 2],
[1, 2]]])
最后,我们可以在x
上应用torch.gather
,在dim=2
上应用idx
。 这将返回一个 3D 张量out
这样:
out[i][j][k] = x[i][j][index[i][j][k]]
让我们调用torch.gather
:
>>> x.gather(dim=2, index=idx)
tensor([[[1, 2],
[1, 3],
[2, 3]],
[[2, 3],
[2, 4],
[3, 4]]])
这是期望的结果。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.