
Replace all zeros with last non-zero value in torch

Is there an efficient way in torch to replace every zero in a tensor with the last non-zero value that precedes it in the same row?

For example if I had the tensor:

tensor([[1, 0, 0, 4, 0, 5, 0, 0],
        [0, 3, 0, 6, 0, 0, 8, 0]])

The output should be:

tensor([[1, 1, 1, 4, 4, 5, 5, 5],
        [0, 3, 3, 6, 6, 6, 8, 8]])

I currently have the following code:

def replace_zeros_with_prev_nonzero(tensor):
    output = tensor.clone()
    for i in range(len(output)):
        prev_value = 0
        for j in range(len(tensor[i])):
            if tensor[i,j] == 0:
                output[i,j] = prev_value
            else:
                prev_value = tensor[i,j].item()      
    return output

But it feels a bit clunky, and I'm sure there must be a better way to do this. Is it possible to write it in fewer lines or, better yet, to parallelise the operation instead of iterating over the tensor element by element like a plain array?

You can remove one of the loops by vectorising over the first dimension (the rows), so that only the loop over the columns remains.

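def replace_zeros_with_prev_nonzero(tensor):
    output = tensor.clone()
    # Column 0 has no left neighbour, so start at column 1.
    for i in range(1, tensor.shape[1]):
        # Rows whose original value in column i is zero.
        mask = tensor[:, i] == 0
        # Copy from the previous column of output, which is already filled in.
        output[mask, i] = output[mask, i-1]

    return output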
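As a quick sanity check (a sketch, assuming PyTorch is installed), the snippet below runs the row-vectorised function on the example from the question and compares it with the question's original double loop on a random tensor containing many zeros. Both functions are copied from this thread under the illustrative names `replace_zeros_loop` and `replace_zeros_rows`.

```python
import torch

def replace_zeros_loop(tensor):
    # The question's original double loop, used as a reference implementation.
    output = tensor.clone()
    for i in range(len(output)):
        prev_value = 0
        for j in range(len(tensor[i])):
            if tensor[i, j] == 0:
                output[i, j] = prev_value
            else:
                prev_value = tensor[i, j].item()
    return output

def replace_zeros_rows(tensor):
    # The answer's version: one Python loop over columns, vectorised over rows.
    output = tensor.clone()
    for i in range(1, tensor.shape[1]):
        mask = tensor[:, i] == 0
        output[mask, i] = output[mask, i - 1]
    return output

example = torch.tensor([[1, 0, 0, 4, 0, 5, 0, 0],
                        [0, 3, 0, 6, 0, 0, 8, 0]])
result = replace_zeros_rows(example)
print(result)

# Random integer tensor where roughly half the entries are forced to zero.
torch.manual_seed(0)
rand = torch.randint(0, 4, (5, 12)) * (torch.rand(5, 12) > 0.5)
assert torch.equal(replace_zeros_loop(rand), replace_zeros_rows(rand))
```

The printed result matches the expected output in the question, including the leading 0 in the second row, which stays 0 because there is no earlier non-zero value to copy.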

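The line output[mask, i] = output[mask, i-1] replaces each 0 with the value to its left. Because it reads from output rather than from tensor, that left-hand value has already been filled in if it was originally 0, so a run of zeros is filled with the last non-zero value before it. Column 0 is never modified, so zeros at the start of a row stay 0.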
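If you want to drop the remaining Python loop as well, one option (a different technique from the answer, sketched here assuming a 2-D tensor filled along dim 1) is to compute, for every position, the column index of the most recent non-zero entry with a running maximum (`cummax`) and then look the values up with `gather`. The function name `fill_zeros_forward` is hypothetical.

```python
import torch

def fill_zeros_forward(t: torch.Tensor) -> torch.Tensor:
    """Replace each zero with the last non-zero value to its left (along dim 1).

    Zeros before the first non-zero entry of a row stay 0.
    """
    # Column index of every element, broadcast to t's shape: [[0, 1, 2, ...], ...]
    idx = torch.arange(t.shape[1], device=t.device).expand_as(t)
    # Keep the index where the value is non-zero, otherwise use 0.
    # Leading zeros therefore point at column 0, which is itself 0 in that case.
    idx = torch.where(t != 0, idx, torch.zeros_like(idx))
    # Running maximum along each row = index of the most recent non-zero entry.
    idx = idx.cummax(dim=1).values
    # Gather the values at those indices.
    return t.gather(1, idx)

example = torch.tensor([[1, 0, 0, 4, 0, 5, 0, 0],
                        [0, 3, 0, 6, 0, 0, 8, 0]])
print(fill_zeros_forward(example))
```

Every step is a single tensor operation, so there is no Python-level loop and the same code can run on GPU tensors; the trade-off is an extra index tensor the same size as the input.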
