
What is the proper way to checkpoint during training when using distributed data parallel (DDP) in PyTorch?

I want the proper and official (bug-free) way to do the following:

  1. resume from a checkpoint to continue training on multiple GPUs
  2. save checkpoints correctly during training with multiple GPUs

For that, my guess is the following:

  1. To do (1), have all processes load the checkpoint from the file, then call DDP(mdl) in each process. I assume the checkpoint saved a ddp_mdl.module.state_dict().
  2. To do (2), simply check which process has rank == 0 and have only that one call torch.save({'model': ddp_mdl.module.state_dict()}).

Approximate code:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def save_ckpt(rank, ddp_model, optimizer, path):
    # only rank 0 writes the checkpoint; save the unwrapped module's state_dict
    if rank == 0:
        state = {'model': ddp_model.module.state_dict(),
                 'optimizer': optimizer.state_dict(),
                }
        torch.save(state, path)

def load_ckpt(path, distributed, gpu, map_location=torch.device('cpu')):
    # load to CPU first, then (if distributed) wrap in DDP on this process's GPU
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...)
    optimizer = ...
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if distributed:
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    return model, optimizer

Is this correct?


One of the reasons I am asking is that distributed code can go subtly wrong, and I want to make sure this does not happen to me. Of course I want to avoid deadlocks, but that would be obvious if it happened to me (e.g. it could perhaps happen if all the processes somehow tried to open the same checkpoint file at the same time; in that case I'd make sure that only one of them loads it at a time, or have rank 0 alone load it and then send it to the rest of the processes).
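For the "rank 0 alone loads it and then sends it to the rest of the processes" option, here is a minimal sketch of what I have in mind. This is just my own guess, not official code; it assumes the process group is already initialized, that dist.broadcast_object_list is available in the installed PyTorch version, and Net() stands in for the real model constructor:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def load_ckpt_broadcast(rank, path, gpu):
    # rank 0 is the only process that touches the checkpoint file;
    # the loaded checkpoint is then broadcast to every other rank
    obj_list = [None]
    if rank == 0:
        obj_list[0] = torch.load(path, map_location='cpu')
    dist.broadcast_object_list(obj_list, src=0)
    checkpoint = obj_list[0]

    model = Net()  # hypothetical model constructor (placeholder)
    model.load_state_dict(checkpoint['model'])
    model = model.to(gpu)
    return DDP(model, device_ids=[gpu])

Compared to having every rank call torch.load on the same file, only one process reads from disk, and the synchronization between ranks is implicit in the broadcast.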

I am also asking because the official docs don't make sense to me. I will paste their code and explanation, since links can die sometimes:

Save and Load Checkpoints

It's common to use torch.save and torch.load to checkpoint modules during training and recover from checkpoints. See SAVING AND LOADING MODELS for more details. When using DDP, one optimization is to save the model in only one process and then load it to all processes, reducing write overhead. This is correct because all processes start from the same parameters and gradients are synchronized in backward passes, and hence optimizers should keep setting parameters to the same values. If you use this optimization, make sure all processes do not start loading before the saving is finished. Besides, when loading the module, you need to provide an appropriate map_location argument to prevent a process from stepping into others' devices. If map_location is missing, torch.load will first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. For more advanced failure recovery and elasticity support, please refer to TorchElastic.

def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.

    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
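If I fold that docs pattern back into my own save_ckpt/load_ckpt above, my current understanding is roughly the sketch below. Again this is only my guess, assuming the process group is already set up and that the checkpoint is written from rank 0, whose parameters live on cuda:0:

import os
import torch
import torch.distributed as dist

def save_then_reload(rank, ddp_model, optimizer, path):
    # rank 0 writes the checkpoint of the unwrapped module
    if rank == 0:
        torch.save({'model': ddp_model.module.state_dict(),
                    'optimizer': optimizer.state_dict()}, path)
    # no rank may start reading before rank 0 has finished writing
    dist.barrier()
    # map tensors saved from cuda:0 onto this process's own GPU
    map_location = {'cuda:0': f'cuda:{rank}'}
    checkpoint = torch.load(path, map_location=map_location)
    ddp_model.module.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    # keep rank 0 from deleting the file while others might still be reading
    dist.barrier()
    if rank == 0:
        os.remove(path)

The second barrier is only there because, unlike in the docs example, there is no backward pass (and hence no AllReduce acting as a synchronization point) between loading and deleting the file.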

Related:

I am looking at the official ImageNet example and here is how they do it. First, they create the model in DDP mode:

model = ResNet50(...)
model = DDP(model,...)

At the save-checkpoint step, they check whether it is the main process, then save the state_dict:

import torch.distributed as dist

if dist.get_rank() == 0:  # check if main process, a simpler way compared to the link
    torch.save({'state_dict': model.state_dict(), ...},
                '/path/to/checkpoint.pth.tar')

During loading, they load the model and put it in DDP mode as usual, without needing to check the rank:

checkpoint = torch.load('/path/to/checkpoint.pth.tar')
model = ResNet50(...)
# load_state_dict returns a result object, not the model, so call it separately
model.load_state_dict(checkpoint['state_dict'])
model = DDP(model, ...)

If you want to load it but not in DDP mode, it is a bit tricky, since for some reason they save it with an extra module. prefix on every key. As solved here, you have to do:

from collections import OrderedDict

checkpoint = torch.load('/path/to/checkpoint.pth.tar')
state_dict = checkpoint['state_dict']

new_state_dict = OrderedDict()
for k, v in state_dict.items():
    name = k[7:]  # remove the 'module.' prefix added by DataParallel/DistributedDataParallel
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)
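As a side note, recent PyTorch versions ship a small helper that does the same prefix stripping in place, which could replace the loop above. This is only a sketch; check that the helper exists in your installed version, and here model is assumed to be the plain (unwrapped) module:

from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

checkpoint = torch.load('/path/to/checkpoint.pth.tar', map_location='cpu')
state_dict = checkpoint['state_dict']
# removes the leading 'module.' from every key, in place, if it is present
consume_prefix_in_state_dict_if_present(state_dict, 'module.')
model.load_state_dict(state_dict)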
