
Avoid Loop for Selecting Dimensions in a Batch Tensor in PyTorch

I have a batch tensor and another tensor holding the indices of the channels to select from the batch tensor. At present, I am looping over the batch tensor, as shown in the code snippet below:

import torch

# create tensors to represent our data in torch format
batch_size = 8
batch_data = torch.rand(batch_size, 3, 240, 320)

# notice that channels_id has 8 elements, i.e., equal to batch_size
channels_id = torch.tensor([2, 0, 2, 1, 0, 2, 1, 0])

This is how I am selecting the channels inside a for loop and then stacking the results into a single tensor:

batch_out = torch.stack([batch_i[channel_i] for batch_i, channel_i in zip(batch_data, channels_id)])
batch_out.size()  # prints torch.Size([8, 240, 320])

It works fine. However, is there a better, more idiomatic PyTorch way to achieve the same result?

As per the hint from @Shai, I could make it work using the torch.gather function. Below is the complete code:

import torch

# create tensors to represent our data in torch format
batch_size = 8
batch_data = torch.rand(batch_size, 3, 240, 320)

# notice that channels_id has 8 elements, i.e., batch_size
channels_id = torch.tensor([2, 0, 2, 1, 0, 2, 1, 0])

# expand channels_id to shape (8, 1, 240, 320): one index per output element
channels_id = channels_id.view(-1, 1, 1, 1).repeat((1, 1) + batch_data.size()[-2:])

batch_out = torch.gather(batch_data, 1, channels_id).squeeze(1)  # squeeze(1) drops only the singleton channel dim
batch_out.size()  # prints torch.Size([8, 240, 320])
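
For reference, with dim=1 torch.gather computes out[i][j][h][w] = batch_data[i][channels_id[i][j][h][w]][h][w], which is why channels_id has to be expanded to provide one index per output element.

As a side note (not from the original thread), the same per-sample channel selection can also be written with advanced integer indexing, which skips the view/repeat step entirely. A minimal self-contained sketch:

import torch

batch_size = 8
batch_data = torch.rand(batch_size, 3, 240, 320)
channels_id = torch.tensor([2, 0, 2, 1, 0, 2, 1, 0])

# pair batch index i (dim 0) with channels_id[i] (dim 1);
# the two index tensors are matched element-wise
batch_out = batch_data[torch.arange(batch_size), channels_id]
batch_out.size()  # prints torch.Size([8, 240, 320])

Both approaches produce the same tensor, which can be verified with torch.equal.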
