
How can I parallelize a for loop for use in PyTorch?

I realize that for loops are generally slow in Python. I have some code that messes around with some tensors:


            for batch_index, mask_batch in enumerate(mask):
                mask_len = torch.sum(mask_batch).int()

                if mask_len == 0:
                    side_input = torch.zeros((max_inp_len, side_input.shape[1])).to(mask.device)
                else:

                    m_nonzero = mask_batch.nonzero().flatten()
                    first_nonzero = m_nonzero[0]
                    last_nonzero = m_nonzero[-1]

                    if side == 'left':
                        end_index = first_nonzero - 1
                        start_index = 0
                    elif side == 'right':
                        start_index = last_nonzero + 1
                        end_index = inputs[batch_index].size(1)

                    side_input = inputs[batch_index][start_index:end_index]

                    if end_index - start_index < max_inp_len:
                        pad_zeros = torch.zeros(
                            (max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
                        if side == 'left':
                            side_input = torch.cat((pad_zeros, side_input), 0)
                        elif side == 'right':
                            side_input = torch.cat((side_input, pad_zeros), 0)

                side_inputs.append(side_input)

        return torch.stack(side_inputs)

I feel like this loop is REALLY slowing things down. Is there some way for me to do it without the loop?

Python does not have true parallelism within any given process (because of the GIL). You would have to spawn a process pool, make the inside of your loop a function taking batch_index and mask_batch, and then map that function over the mask object from your current for loop. The catch is that I don't know how nicely PyTorch plays with multiprocessing.

Like so:

def f(batch_index, mask_batch):
    # Number of positions set in this batch element's mask
    mask_len = torch.sum(mask_batch).int()

    if mask_len == 0:
        # Empty mask: return an all-zero block of the padded size
        # (side_input does not exist yet in this branch, so take the feature
        # width from inputs instead)
        side_input = torch.zeros((max_inp_len, inputs[batch_index].shape[1])).to(mask.device)
    else:
        # First and last positions where the mask is non-zero
        m_nonzero = mask_batch.nonzero().flatten()
        first_nonzero = m_nonzero[0]
        last_nonzero = m_nonzero[-1]

        if side == 'left':
            end_index = first_nonzero - 1
            start_index = 0
        elif side == 'right':
            start_index = last_nonzero + 1
            end_index = inputs[batch_index].size(1)

        # Slice out the part of this input on the requested side of the mask
        side_input = inputs[batch_index][start_index:end_index]

        # Zero-pad the slice up to max_inp_len, on the side away from the mask
        if end_index - start_index < max_inp_len:
            pad_zeros = torch.zeros((max_inp_len - side_input.shape[0], side_input.shape[1])).to(mask.device)
            if side == 'left':
                side_input = torch.cat((pad_zeros, side_input), 0)
            elif side == 'right':
                side_input = torch.cat((side_input, pad_zeros), 0)
    return side_input
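
If you go that route, one way to drive it is to map f over the batch with a process pool. This is only a rough sketch: it assumes mask, inputs, side and max_inp_len are visible to the worker processes (for example as module-level globals on a fork-based platform), and the pool size of 4 is arbitrary.

import torch
import torch.multiprocessing as mp

# Rough sketch: map f over (batch_index, mask_batch) pairs with a process pool.
# Whether this is actually faster than the plain loop depends on how well
# PyTorch plays with multiprocessing for your workload.
if __name__ == '__main__':
    with mp.Pool(processes=4) as pool:
        side_inputs = pool.starmap(f, list(enumerate(mask)))
    result = torch.stack(side_inputs)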

The other thing you can look at is further vectorizing the code. Most things in PyTorch and NumPy can be vectorized away by using built-in functions and adding another dimension onto your tensors that represents the "loop" dimension. This lets PyTorch handle the parallelism for you.
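
As an illustration (a minimal sketch with a made-up mask, not your exact logic), the first and last non-zero positions of every row can be found in one shot instead of per batch element:

import torch

# Hypothetical (batch, seq_len) mask, just for illustration
mask = torch.tensor([[0, 0, 1, 1, 0],
                     [1, 1, 0, 0, 0],
                     [0, 0, 0, 0, 0]])

seq_len = mask.size(1)
positions = torch.arange(seq_len).expand_as(mask)

has_nonzero = mask.bool().any(dim=1)  # rows whose mask is not all zero
# Replace masked-out positions with sentinels, then reduce along each row
first_nonzero = torch.where(mask.bool(), positions, torch.full_like(positions, seq_len)).min(dim=1).values
last_nonzero = torch.where(mask.bool(), positions, torch.full_like(positions, -1)).max(dim=1).values

The slicing and padding can in principle be expressed the same way (for example by comparing positions against per-row start/end indices to build a batched mask), which is harder to write but removes the Python loop entirely.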

PyTorch also has a concept of devices, and you could put different iterations of the loop on different devices. Again, this requires you to make a function out of the loop body, and probably to take the device it runs on as an input.
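
A rough sketch of that idea (hypothetical process_chunk function and toy data, nothing from your code): split the batch into chunks and send each chunk to a different device, falling back to the CPU if no GPUs are available.

import torch

devices = [torch.device(f'cuda:{i}') for i in range(torch.cuda.device_count())] or [torch.device('cpu')]

def process_chunk(chunk, device):
    # Stand-in for the real per-chunk work
    return (chunk.to(device) * 2).cpu()

batch = torch.randn(8, 4)
chunks = batch.chunk(len(devices), dim=0)  # one chunk per device
result = torch.cat([process_chunk(c, d) for c, d in zip(chunks, devices)], dim=0)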

Lastly, you can look into just-in-time compilation, such as Numba or torch.jit, to compile the loop and take some of that overhead away for you.
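
For example, a minimal torch.jit.script sketch (a toy function, not yours) that compiles a Python loop over tensor rows so it runs in TorchScript instead of the Python interpreter:

import torch

@torch.jit.script
def row_sums(x: torch.Tensor) -> torch.Tensor:
    # The loop is compiled by TorchScript rather than executed by the interpreter
    out = torch.zeros(x.size(0))
    for i in range(x.size(0)):
        out[i] = x[i].sum()
    return out

print(row_sums(torch.randn(4, 3)))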

None of this will work (most likely) if the mask is of an unknown length. If it is of a known length, I think vectorization, as hard as it is, is likely your best choice.

You could create a function containing the logic of one loop iteration and launch it as a thread for each element of the batch (see the Python threading docs). You could also use the asyncio library for concurrency, but you would probably see a smaller improvement.

A good example of spawning a thread for each element of a list can be found here; a minimal sketch of the same pattern is below.
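
This sketch uses a hypothetical worker doing a stand-in computation, not your loop body; note that threads only overlap to the extent that the underlying PyTorch ops release the GIL.

import threading
import torch

def worker(batch_index, mask_batch, results):
    # Stand-in for the real per-element work
    results[batch_index] = torch.sum(mask_batch)

mask = torch.randint(0, 2, (4, 10))
results = [None] * mask.size(0)

threads = [threading.Thread(target=worker, args=(i, m, results)) for i, m in enumerate(mask)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(torch.stack(results))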
