[英]Replace all zeros with last non-zero value in torch
有沒有什么有效的方法可以用火炬中最后一個非零值替換張量中的所有零?
例如,如果我有張量:
tensor([[1, 0, 0, 4, 0, 5, 0, 0],
[0, 3, 0, 6, 0, 0, 8, 0]])
輸出應該是:
tensor([[1, 1, 1, 4, 4, 5, 5, 5],
[0, 3, 3, 6, 6, 6, 8, 8]])
我目前有以下代碼:
def replace_zeros_with_prev_nonzero(tensor):
output = tensor.clone()
for i in range(len(output)):
prev_value = 0
for j in range(len(tensor[i])):
if tensor[i,j] == 0:
output[i,j] = prev_value
else:
prev_value = tensor[i,j].item()
return output
但感覺雖然有點笨拙,我相信必須有更好的方法來做到這一點。 那么是否可以用更少的行來編寫它,或者更好地並行化操作而不將張量視為數組?
您可以通過對一維進行矢量化來刪除其中一個循環。
def replace_zeros_with_prev_nonzero(tensor):
output = tensor.clone()
for i in range(1, tensor.shape[1]):
mask = tensor[:, i] == 0
output[mask, i] = output[mask, i-1]
return output
output[mask, i] = output[mask, i-1]
將 0 替換為前一個值(如果最初為 0,則它本身將被替換,除了第 0 個索引)。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.