繁体   English   中英

用火炬中最后一个非零值替换所有零

[英]Replace all zeros with last non-zero value in torch

有没有什么有效的方法可以用火炬中最后一个非零值替换张量中的所有零?

例如,如果我有张量:

tensor([[1, 0, 0, 4, 0, 5, 0, 0],
        [0, 3, 0, 6, 0, 0, 8, 0]])

输出应该是:

tensor([[1, 1, 1, 4, 4, 5, 5, 5],
        [0, 3, 3, 6, 6, 6, 8, 8]])

我目前有以下代码:

def replace_zeros_with_prev_nonzero(tensor):
    output = tensor.clone()
    for i in range(len(output)):
        prev_value = 0
        for j in range(len(tensor[i])):
            if tensor[i,j] == 0:
                output[i,j] = prev_value
            else:
                prev_value = tensor[i,j].item()      
    return output

但感觉虽然有点笨拙,我相信必须有更好的方法来做到这一点。 那么是否可以用更少的行来编写它,或者更好地并行化操作而不将张量视为数组?

您可以通过对一维进行矢量化来删除其中一个循环。

def replace_zeros_with_prev_nonzero(tensor):
    output = tensor.clone()
    for i in range(1, tensor.shape[1]):
        mask = tensor[:, i] == 0
        output[mask, i] = output[mask, i-1]

    return output

output[mask, i] = output[mask, i-1]将 0 替换为前一个值(如果最初为 0,则它本身将被替换,除了第 0 个索引)。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM