I realize that for loops are slow in Python in general. I have some code that messes around with some tensors:
```python
for batch_index, mask_batch in enumerate(mask):
    mask_len = torch.sum(mask_batch).int()
    if mask_len == 0:
        side_input = torch.zeros((max_inp_len, side_input.shape[1])).to(mask.device)
    else:
        m_nonzero = mask_batch.nonzero().flatten()
        first_nonzero = m_nonzero[0]
        last_nonzero = m_nonzero[-1]

        if side == 'left':
            end_index = first_nonzero - 1
            start_index = 0
        elif side == 'right':
            start_index = last_nonzero + 1
            end_index = inputs[batch_index].size(1)

        side_input = inputs[batch_index][start_index:end_index]

        if end_index - start_index < max_inp_len:
            pad_zeros = torch.zeros(
                (max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
            if side == 'left':
                side_input = torch.cat((pad_zeros, side_input), 0)
            elif side == 'right':
                side_input = torch.cat((side_input, pad_zeros), 0)

    side_inputs.append(side_input)

return torch.stack(side_inputs)
```
I feel like this loop is REALLY slowing things down. Is there some way for me to do it without the loop?
Python does not have true parallelism within any given process. You would have to spawn a ProcessPool and make the inside of your loop a function taking `batch_index` and `mask_batch`, then map that function over the `mask` object in your current for loop. Thing is, I don't know if PyTorch will play nicely with this.
Like so:

```python
def f(batch_index, mask_batch):
    mask_len = torch.sum(mask_batch).int()
    if mask_len == 0:
        side_input = torch.zeros((max_inp_len, side_input.shape[1])).to(mask.device)
    else:
        m_nonzero = mask_batch.nonzero().flatten()
        first_nonzero = m_nonzero[0]
        last_nonzero = m_nonzero[-1]

        if side == 'left':
            end_index = first_nonzero - 1
            start_index = 0
        elif side == 'right':
            start_index = last_nonzero + 1
            end_index = inputs[batch_index].size(1)

        side_input = inputs[batch_index][start_index:end_index]

        if end_index - start_index < max_inp_len:
            pad_zeros = torch.zeros(
                (max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
            if side == 'left':
                side_input = torch.cat((pad_zeros, side_input), 0)
            elif side == 'right':
                side_input = torch.cat((side_input, pad_zeros), 0)

    return side_input
```
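As a self-contained sketch of the mapping step (using a stand-in function in place of `f` above, since multiprocessing workers must be able to import the function at module level):

```python
from concurrent.futures import ProcessPoolExecutor

def square(i):
    # stand-in for the per-batch function f above; the real code would
    # take (batch_index, mask_batch) and return a padded tensor
    return i * i

if __name__ == "__main__":
    with ProcessPoolExecutor() as pool:
        # map the function over the batch dimension in parallel processes;
        # arguments and results are pickled, which adds real overhead
        results = list(pool.map(square, range(4)))
```

For the tensor code above that would look like `pool.map(f, range(len(mask)), mask)` followed by `torch.stack` on the results, but be aware that pickling tensors between processes can easily cost more than the loop saves.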
The other thing you can look at is further vectorizing the code. Most operations in PyTorch and NumPy can be vectorized away by using built-in functions and adding another dimension onto your tensors that represents the "loop" dimension. This will allow PyTorch to handle the parallelism for you.
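For example, the first and last nonzero positions used in your loop can be computed for the whole batch at once (a sketch with a made-up `mask`; `torch.where` plus a broadcast index vector does the per-row reduction):

```python
import torch

mask = torch.tensor([[0, 1, 1, 0],
                     [1, 1, 0, 0],
                     [0, 0, 0, 1]])

seq_len = mask.shape[1]
idx = torch.arange(seq_len)

# per-batch count of nonzero mask entries, no Python loop
mask_len = mask.sum(dim=1)

# replace masked-out positions with sentinel values, then reduce per row:
first_nonzero = torch.where(mask.bool(), idx, seq_len).min(dim=1).values
last_nonzero = torch.where(mask.bool(), idx, -1).max(dim=1).values

print(first_nonzero.tolist())  # [1, 0, 3]
print(last_nonzero.tolist())   # [2, 1, 3]
```

The slicing and padding can also be expressed without a loop once these indices are known; the hard part is that rows end up with different lengths, which is exactly what padding to `max_inp_len` resolves.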
PyTorch also has a concept of devices that you could put different iterations of the loop on; again, this will require you to make a function for the loop body, perhaps taking the device it runs on as an input.
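A minimal sketch of the device mechanism (this moves tensors, not loop iterations; splitting iterations across GPUs would mean running code like this once per chunk):

```python
import torch

# pick a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(4, 8)
x = x.to(device)        # move the tensor to the chosen device
y = (x * 2).sum(dim=1)  # ops run on whatever device x lives on
```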
Lastly, you can look into just-in-time compilation like Numba or torch.jit to perform auto-vectorization for you.
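For instance, with torch.jit (a sketch on a toy masked-sum function; Numba works similarly for NumPy code):

```python
import torch

@torch.jit.script
def masked_sum(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # TorchScript compiles this function, so even an explicit loop
    # here would avoid per-iteration Python interpreter overhead
    return (x * mask).sum(dim=1)

x = torch.ones(3, 4)
m = torch.tensor([[1., 1., 0., 0.],
                  [1., 0., 0., 0.],
                  [1., 1., 1., 1.]])
print(masked_sum(x, m).tolist())  # [2.0, 1.0, 4.0]
```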
None of this will work (most likely) if the `mask` is of an unknown length. If it is of a known length, I think vectorization, as hard as it is, is likely your best choice.
You should create a function containing the logic behind a loop iteration, and launch it as a thread for each column (see docs here ). You could also use the asyncio library for concurrency, but you would probably obtain smaller improvements.
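A sketch of the thread-per-iteration idea using `concurrent.futures` (with a stand-in worker function; note that because of the GIL, threads only help here to the extent that PyTorch releases the GIL inside its C++ kernels):

```python
from concurrent.futures import ThreadPoolExecutor

def process_column(i):
    # stand-in for one loop iteration's work
    return i * i

with ThreadPoolExecutor(max_workers=4) as pool:
    # threads share memory, so no pickling overhead as with processes
    results = list(pool.map(process_column, range(8)))

print(results)  # [0, 1, 4, 9, 16, 25, 36, 49]
```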
A good example of spawning a thread for each element of a list can be read here .