[英]Replace all zeros with last non-zero value in torch
有没有什么有效的方法可以用火炬中最后一个非零值替换张量中的所有零?
例如,如果我有张量:
tensor([[1, 0, 0, 4, 0, 5, 0, 0],
[0, 3, 0, 6, 0, 0, 8, 0]])
输出应该是:
tensor([[1, 1, 1, 4, 4, 5, 5, 5],
[0, 3, 3, 6, 6, 6, 8, 8]])
我目前有以下代码:
def replace_zeros_with_prev_nonzero(tensor):
output = tensor.clone()
for i in range(len(output)):
prev_value = 0
for j in range(len(tensor[i])):
if tensor[i,j] == 0:
output[i,j] = prev_value
else:
prev_value = tensor[i,j].item()
return output
但感觉虽然有点笨拙,我相信必须有更好的方法来做到这一点。 那么是否可以用更少的行来编写它,或者更好地并行化操作而不将张量视为数组?
您可以通过对一维进行矢量化来删除其中一个循环。
def replace_zeros_with_prev_nonzero(tensor):
output = tensor.clone()
for i in range(1, tensor.shape[1]):
mask = tensor[:, i] == 0
output[mask, i] = output[mask, i-1]
return output
output[mask, i] = output[mask, i-1]
将 0 替换为前一个值(如果最初为 0,则它本身将被替换,除了第 0 个索引)。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.