[英]How can I parallelize a for loop for use in PyTorch?
我意识到,对于Python
, for
循环通常很慢。 我有一些与一些张量混淆的代码:
for batch_index, mask_batch in enumerate(mask):
mask_len = torch.sum(mask_batch).int()
if mask_len == 0:
side_input = torch.zeros((max_inp_len, side_input.shape[1])).to(mask.device)
else:
m_nonzero = mask_batch.nonzero().flatten()
first_nonzero = m_nonzero[0]
last_nonzero = m_nonzero[-1]
if side == 'left':
end_index = first_nonzero - 1
start_index = 0
elif side == 'right':
start_index = last_nonzero + 1
end_index = inputs[batch_index].size(1)
side_input = inputs[batch_index][start_index:end_index]
if end_index - start_index < max_inp_len:
pad_zeros = torch.zeros(
(max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
if side == 'left':
side_input = torch.cat((pad_zeros, side_input), 0)
elif side == 'right':
side_input = torch.cat((side_input, pad_zeros), 0)
side_inputs.append(side_input)
return torch.stack(side_inputs)
我觉得这个循环真的让事情变慢了。 有没有办法让我在没有循环的情况下做到这一点?
Python 在任何给定进程中都没有真正的并行性。 You would have to spawn a ProcessPool and make the inside of your loop a function taking batch_index, mask_batch
, then map that function over the mask
object in your current for loop. 问题是,我不知道 PyTorch 是否能很好地配合这个。
像这样
def f(batch_index, mask_batch):
mask_len = torch.sum(mask_batch).int()
if mask_len == 0:
side_input = torch.zeros((max_inp_len, side_input.shape[1])).to(mask.device)
else:
m_nonzero = mask_batch.nonzero().flatten()
first_nonzero = m_nonzero[0]
last_nonzero = m_nonzero[-1]
if side == 'left':
end_index = first_nonzero - 1
start_index = 0
elif side == 'right':
start_index = last_nonzero + 1
end_index = inputs[batch_index].size(1)
side_input = inputs[batch_index][start_index:end_index]
if end_index - start_index < max_inp_len:
pad_zeros = torch.zeros((max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
if side == 'left':
side_input = torch.cat((pad_zeros, side_input), 0)
elif side == 'right':
side_input = torch.cat((side_input, pad_zeros), 0)
return side_input
您可以查看的其他内容是进一步矢量化代码。 PyTorch 和 Numpy 中的大多数内容都可以通过使用内置函数并将另一个维度添加到代表“循环”维度的张量上来向量化。 这将允许 PyTorch 为您处理并行度。
PyTorch 可能有一个设备概念,您可以在这些设备上放置不同的循环迭代,这再次要求您为此循环制作 function,并可能将其继续作为输入的设备。
最后,您可以查看像 Numba 或 torch.jit 这样的即时编译来为您执行自动矢量化。
如果mask
的长度未知,这一切都不起作用(很可能)。 如果它的长度已知,我认为矢量化可能是您的最佳选择。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.