繁体   English   中英

Pytorch 在火炬张量之间移动向量的操作

[英]Pytorch operation for moving vectors between torch tensors

假设我们有火炬张量:

A: with shape BxHxW and values in {0,1}, where 0 and 1 are classes
B: with shape Bx2xD and real values, where D is the dimensionality of our vector

We want to create a new tensor of shape BxDxHxW that holds in each index specified in the spatial dimension (HxW), the vector that corresponds to its class (specified by A).

pytorch 中有 function 实现了吗? 我试过 torch scatter 但认为情况并非如此。

您实际上是在寻找反向操作,即使用包含在另一个张量中的索引从一个张量收集值。 这是处理这种索引场景并毫不费力地应用torch.gather的规范答案。

让我们用虚拟数据设置一个最小的例子:

>>> b = 2; d = 3; h = 2; w = 1
>>> A = torch.randint(0, 2, (b,h,w)) # bhw
>>> B = torch.rand(b,2,d) # b2d
  1. 根据您的问题定义要执行的索引规则,在这里:

     # out[b, d, h, w] = B[b, A[b, h, w]]
  2. 我们正在使用A中的值寻找B的第二维的某种索引。 应用torch.gather时,所有三个张量(输入、索引器和输出)必须具有相同的维度数相同的维度大小,但要索引的维度除外,这里的dim=1 观察我们的案例,我们必须坚持这种模式:

     # out[b, 1, d, h, w] = B[b, A[b, 1, d, h, w], d, h, w]
  3. 因此,为了解释这种变化,我们需要在输入张量和索引张量上取消压缩/扩展额外的维度。 因此,要坚持上述形状,我们可以这样做:

    首先,我们取消压缩A上的二维:

     >>> A_ = A[:,None,None].expand(-1,1,d,-1,-1)

    其次,我们取消压缩B上的两个维度:

     >>> B_ = B[..., None, None].expand(-1,-1,-1,h,w)

    请注意,扩展维度不会执行复制。 它只是张量基础数据的一个视图。 在此步骤中, A_最终具有(b, 1, d, h, w)的形状,而B_具有(b, 2, d, h, w)的形状。

  4. 现在,我们可以使用A_B_dim=1上简单地应用torch.gather

     >>> out = B_.gather(dim=1, index=A_)

    我们必须为dim=1使用 singleton 维度,这样我们就可以在生成的张量上压缩它。 这是您想要的形状(b, d, h, w)的结果:

     >>> out[:,0]

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM