繁体   English   中英

如何并行化用于 PyTorch 的 for 循环?

[英]How can I parallelize a for loop for use in PyTorch?

我意识到,对于Pythonfor循环通常很慢。 我有一些与一些张量混淆的代码:


            for batch_index, mask_batch in enumerate(mask):
                mask_len = torch.sum(mask_batch).int()

                if mask_len == 0:
                    side_input = torch.zeros((max_inp_len, side_input.shape[1])).to(mask.device)
                else:

                    m_nonzero = mask_batch.nonzero().flatten()
                    first_nonzero = m_nonzero[0]
                    last_nonzero = m_nonzero[-1]

                    if side == 'left':
                        end_index = first_nonzero - 1
                        start_index = 0
                    elif side == 'right':
                        start_index = last_nonzero + 1
                        end_index = inputs[batch_index].size(1)

                    side_input = inputs[batch_index][start_index:end_index]

                    if end_index - start_index < max_inp_len:
                        pad_zeros = torch.zeros(
                            (max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
                        if side == 'left':
                            side_input = torch.cat((pad_zeros, side_input), 0)
                        elif side == 'right':
                            side_input = torch.cat((side_input, pad_zeros), 0)

                side_inputs.append(side_input)

        return torch.stack(side_inputs)

我觉得这个循环真的让事情变慢了。 有没有办法让我在没有循环的情况下做到这一点?

Python 在任何给定进程中都没有真正的并行性。 You would have to spawn a ProcessPool and make the inside of your loop a function taking batch_index, mask_batch , then map that function over the mask object in your current for loop. 问题是,我不知道 PyTorch 是否能很好地配合这个。

像这样

def f(batch_index, mask_batch):
    mask_len = torch.sum(mask_batch).int()

    if mask_len == 0:
        side_input = torch.zeros((max_inp_len, side_input.shape[1])).to(mask.device)
    else:
        m_nonzero = mask_batch.nonzero().flatten()
        first_nonzero = m_nonzero[0]
        last_nonzero = m_nonzero[-1]

        if side == 'left':
            end_index = first_nonzero - 1
            start_index = 0
        elif side == 'right':
            start_index = last_nonzero + 1
            end_index = inputs[batch_index].size(1)

            side_input = inputs[batch_index][start_index:end_index]

            if end_index - start_index < max_inp_len:
                pad_zeros = torch.zeros((max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
                if side == 'left':
                    side_input = torch.cat((pad_zeros, side_input), 0)
                elif side == 'right':
                    side_input = torch.cat((side_input, pad_zeros), 0)
    return side_input

您可以查看的其他内容是进一步矢量化代码。 PyTorch 和 Numpy 中的大多数内容都可以通过使用内置函数并将另一个维度添加到代表“循环”维度的张量上来向量化。 这将允许 PyTorch 为您处理并行度。

PyTorch 可能有一个设备概念,您可以在这些设备上放置不同的循环迭代,这再次要求您为此循环制作 function,并可能将其继续作为输入的设备。

最后,您可以查看像 Numba 或 torch.jit 这样的即时编译来为您执行自动矢量化。

如果mask的长度未知,这一切都不起作用(很可能)。 如果它的长度已知,我认为矢量化可能是您的最佳选择。

您应该创建一个包含循环迭代背后的逻辑的 function,并将其作为每个列的线程启动(请参阅此处的文档)。 您也可以使用asyncio库进行并发,但您可能会获得较少的改进。

可以在此处阅读为列表的每个元素生成线程的一个很好的示例。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM