繁体   English   中英

在 PyTorch 中使用分布式数据并行 (DDP) 时,在训练期间检查点的正确方法是什么?

[英]What is the proper way to checkpoint during training when using distributed data parallel (DDP) in PyTorch?

我想要(正确和官方 - 无错误的方式)做:

  1. 从检查点恢复以继续在多个 GPU 上进行训练
  2. 在使用多个 GPU 进行训练期间正确保存检查点

为此,我的猜测如下:

  1. 要做 1 我们让所有进程从文件中加载检查点,然后为每个进程调用DDP(mdl) 我假设检查点保存了一个ddp_mdl.module.state_dict()
  2. 要做 2 只需检查谁是 rank = 0 并让那个人做 torch.save({'model': ddp_mdl.module.state_dict()})

近似代码:

def save_ckpt(rank, ddp_model, path):
    if rank == 0:
        state = {'model': ddp_model.module.state_dict(),
             'optimizer': optimizer.state_dict(),
            }
        torch.save(state, path)

def load_ckpt(path, distributed, map_location=map_location=torch.device('cpu')):
    # loads to
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...)
    optimizer = ...
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if distributed:
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    return model

这个对吗?


我要问的原因之一是分布式代码可能会出现 go 微妙的错误。 我想确保这不会发生在我身上。 当然我想避免死锁,但如果它发生在我身上就会很明显(例如,如果所有进程以某种方式试图同时打开同一个 ckpt 文件,可能会发生这种情况。在那种情况下,我会以某种方式确保其中只有一个一次加载一个,或者 rank 0 仅加载它,然后将其发送到进程的 rest)。

我也在问,因为官方文档对我没有意义 我将粘贴他们的代码和解释,因为链接有时会失效:

保存和加载检查点 在训练期间使用 torch.save 和 torch.load 来检查点模块并从检查点恢复是很常见的。 有关详细信息,请参阅保存和加载模型。 使用 DDP 时,一项优化是将 model 仅保存在一个进程中,然后将其加载到所有进程中,从而减少写入开销。 这是正确的,因为所有过程都从相同的参数开始,并且梯度在反向传递中是同步的,因此优化器应该保持将参数设置为相同的值。 如果您使用此优化,请确保在保存完成之前所有进程都不会开始加载。 此外,在加载模块时,您需要提供适当的 map_location 参数以防止进程进入其他设备。 如果缺少 map_location,torch.load 将首先将模块加载到 CPU,然后将每个参数复制到保存位置,这将导致同一台机器上的所有进程使用同一组设备。 有关更高级的故障恢复和弹性支持,请参阅 TorchElastic。

def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn = nn.MSELoss()
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()

有关的:

我正在查看官方ImageNet 示例,这是他们的做法。 首先,他们在DDP 模式下创建 model :

model = ResNet50(...)
model = DDP(model,...)

保存检查点,他们检查它是否是主进程,然后保存state_dict

import torch.distributed as dist

if dist.get_rank() == 0:  # check if main process, a simpler way compared to the link
    torch.save({'state_dict': model.state_dict(), ...},
                '/path/to/checkpoint.pth.tar')

在加载过程中,他们加载 model 并像往常一样将其置于 DDP 模式,无需检查排名:

checkpoint = torch.load('/path/to/checkpoint.pth.tar')
model = ResNet50(...).load_state_dict(checkpoint['state_dict'])
model = DDP(...)

如果你想加载它但不是在 DDP 模式下,这有点棘手,因为出于某种原因,他们用额外的后缀module保存它。 如此解决,您必须执行以下操作:

state_dict = torch.load(checkpoint['state_dict'])
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM