繁体   English   中英

PyTorch memory for 循环中的泄漏引用循环

[英]PyTorch memory leak reference cycle in for loop

在我的 Mac M1 GPU 上使用 PyTorch mps 接口迭代更新 PyTorch 中的张量时,我面临 memory 泄漏。 以下是复制行为的最小可重现示例:

import torch 

def leak_example(p1, device):
    
    t1 = torch.rand_like(p1, device = device) # torch.cat((torch.diff(ubar.detach(), dim=0).detach().clone(), torch.zeros_like(ubar.detach()[:1,:,:,:], dtype = torch.float32)), dim = 0)
    u1 = p1.detach() + 2 * (t1.detach())
    
    B = torch.rand_like(u1, device = device)
    mask = u1 < B
    
    a1 = u1.detach().clone()
    a1[~mask] = torch.rand_like(a1)[~mask]
    return a1

if torch.cuda.is_available(): # cuda gpus
    device = torch.device("cuda")
elif torch.backends.mps.is_available(): # mac gpus
    device = torch.device("mps")
torch.set_grad_enabled(False)
        
p1 = torch.rand(5, 5, 224, 224, device = device)
for i in range(10000):
    p1 = leak_example(p1, device)    

当我执行这个循环时,我的 Mac 的 GPU memory 稳步增长。 我已经尝试在 Google Colab 的 CUDA GPU 上运行它,它的行为似乎类似,随着循环的进行,GPU 的活动 memory、不可释放 memory 和分配 memory 增加。

我试过分离和克隆张量并使用 weakrefs,但无济于事。 有趣的是,如果我不将 leak_example 的leak_example重新分配给p1 ,行为就会消失,所以它看起来确实与递归分配有关。 有谁知道我该如何解决这个问题?

我想我找到了泄漏的原因,这是掩码分配。 用等效的torch.where()语句替换它会使泄漏消失。 我想这与masked_scatter没有在 PyTorch(还)中实现 MPS 支持有关?

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM