What is the proper way to checkpoint during training when using distributed data parallel (DDP) in PyTorch?
I want to do this the proper, official, bug-free way: save checkpoints during training and resume from them.
For that, my guess is the following: every process loads the checkpoint and then calls `DDP(mdl)`, and only rank 0 saves. I assume the checkpoint stores `ddp_mdl.module.state_dict()`. Approximate code:
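```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def save_ckpt(rank, ddp_model, optimizer, path):
    # Only rank 0 writes; all ranks hold identical weights after each step.
    if rank == 0:
        state = {'model': ddp_model.module.state_dict(),  # unwrap DDP: no 'module.' prefix
                 'optimizer': optimizer.state_dict(),
                 }
        torch.save(state, path)


def load_ckpt(path, distributed, gpu, map_location=torch.device('cpu')):
    # Every process loads the checkpoint onto the CPU first
    checkpoint = torch.load(path, map_location=map_location)
    model = Net(...).to(gpu)  # Net(...) is a placeholder for the model class
    optimizer = ...  # placeholder: build the optimizer over model.parameters()
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if distributed:
        # DDP expects the module to already be on device `gpu`
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
    return model, optimizer
```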
Is this correct?
One of the reasons I am asking is that distributed code can go subtly wrong, and I want to make sure that does not happen to me. Of course I want to avoid deadlocks, but those would be obvious if they happened (e.g. perhaps one could happen if all the processes somehow tried to open the same ckpt file at the same time; in that case I'd make sure that only one of them loads it at a time, or have rank 0 alone load it and then send it to the rest of the processes).
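If you do want only rank 0 to touch the file, one way to ship the result to the other ranks is `torch.distributed.broadcast_object_list`. The sketch below assumes an already-initialized process group, and the helper name is hypothetical. Note that several processes merely reading the same file does not by itself deadlock; what hangs a job is a collective call (such as this broadcast or a barrier) that some rank never reaches. The broadcast pickles the whole checkpoint, so every rank still ends up with its own full copy in memory.

```python
import torch
import torch.distributed as dist


def load_on_rank0_and_broadcast(path):
    """Read a checkpoint on rank 0 only and share it with every other rank.

    Assumes dist.init_process_group(...) has already been called. With the
    NCCL backend, call torch.cuda.set_device(local_rank) first, because the
    serialized object is staged on the current CUDA device.
    """
    holder = [None]  # broadcast_object_list fills this list in place
    if dist.get_rank() == 0:
        # Only one process reads the file.
        holder[0] = torch.load(path, map_location="cpu")
    # Collective call: every rank must reach this line, or the job hangs.
    dist.broadcast_object_list(holder, src=0)
    return holder[0]
```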
I am also asking because the official docs don't make sense to me. I will paste their code and explanation, since links can die:
Save and Load Checkpoints

It's common to use torch.save and torch.load to checkpoint modules during training and recover from checkpoints. See SAVING AND LOADING MODELS for more details. When using DDP, one optimization is to save the model in only one process and then load it to all processes, reducing write overhead. This is correct because all processes start from the same parameters and gradients are synchronized in backward passes, and hence optimizers should keep setting parameters to the same values. If you use this optimization, make sure all processes do not start loading before the saving is finished. Besides, when loading the module, you need to provide an appropriate map_location argument to prevent a process to step into others' devices. If map_location is missing, torch.load will first load the module to CPU and then copy each parameter to where it was saved, which would result in all processes on the same machine using the same set of devices. For more advanced failure recovery and elasticity support, please refer to TorchElastic.
```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP

# setup(), cleanup() and ToyModel are defined earlier in the same tutorial.


def demo_checkpoint(rank, world_size):
    print(f"Running DDP checkpoint example on rank {rank}.")
    setup(rank, world_size)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
    if rank == 0:
        # All processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes.
        # Therefore, saving it in one process is sufficient.
        torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)

    # Use a barrier() to make sure that process 1 loads the model after process
    # 0 saves it.
    dist.barrier()
    # configure map_location properly
    map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
    ddp_model.load_state_dict(
        torch.load(CHECKPOINT_PATH, map_location=map_location))

    optimizer.zero_grad()
    outputs = ddp_model(torch.randn(20, 10))
    labels = torch.randn(20, 5).to(rank)
    loss_fn = nn.MSELoss()
    loss_fn(outputs, labels).backward()
    optimizer.step()

    # Not necessary to use a dist.barrier() to guard the file deletion below
    # as the AllReduce ops in the backward pass of DDP already served as
    # a synchronization.
    if rank == 0:
        os.remove(CHECKPOINT_PATH)

    cleanup()
```
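Applying the same pattern to saving and resuming with optimizer state, a minimal sketch could look like this. It assumes one process per device and an initialized process group; the function names, the `epoch` entry and the `device` argument are illustrative choices, not an official API. Saving `ddp_model.module.state_dict()` keeps the keys free of the `module.` prefix, `dist.barrier()` stops other ranks from reading a half-written file, and passing each rank's own device as `map_location` (for example `torch.device(f"cuda:{local_rank}")`, where `local_rank` is hypothetical) plays the role of the `{'cuda:0': 'cuda:rank'}` mapping in the tutorial.

```python
import torch
import torch.distributed as dist


def save_checkpoint(ddp_model, optimizer, epoch, path):
    """Rank 0 writes; every rank then waits so nobody reads a partial file."""
    if dist.get_rank() == 0:
        torch.save({
            "model": ddp_model.module.state_dict(),  # unwrapped: no 'module.' prefix
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        }, path)
    dist.barrier()  # every rank must call this, not only rank 0


def load_checkpoint(ddp_model, optimizer, path, device):
    """Every rank reads the file, mapping all tensors onto its own device."""
    ckpt = torch.load(path, map_location=device)
    ddp_model.module.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1  # epoch to resume training from
```

As in the tutorial, the weights are loaded after wrapping, and since every rank loads the same file, all replicas start the resumed run with identical parameters.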
Related:
I am looking at the official ImageNet example, and here's how they do it. First, they create the model in DDP mode:
```python
model = ResNet50(...)
model = DDP(model, ...)
```
When saving a checkpoint, they check whether the current process is the main process and only then save the `state_dict`:
```python
import torch
import torch.distributed as dist

if dist.get_rank() == 0:  # check if main process, a simpler way compared to the link
    torch.save({'state_dict': model.state_dict(),
                # ...
                },
               '/path/to/checkpoint.pth.tar')
```
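In this snippet `model` is the DDP wrapper, which keeps the original network in its `module` attribute, so every key of `model.state_dict()` starts with `module.`. The sketch below makes the difference visible; it assumes a process group is already initialized and uses a tiny CPU `nn.Linear` as a stand-in for ResNet50 (the helper name is hypothetical).

```python
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def ddp_vs_module_keys():
    """Compare state_dict keys of a DDP wrapper and of the module it wraps.

    Assumes dist.init_process_group(...) has already been called; the CPU
    nn.Linear stands in for a real network such as ResNet50.
    """
    ddp_model = DDP(nn.Linear(3, 1))  # CPU module, so device_ids is left unset
    wrapper_keys = list(ddp_model.state_dict().keys())
    inner_keys = list(ddp_model.module.state_dict().keys())
    return wrapper_keys, inner_keys
```

On a single-process `gloo` group this returns `(['module.weight', 'module.bias'], ['weight', 'bias'])`, which is why saving `model.module.state_dict()` instead produces prefix-free keys.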
During loading, they load the checkpoint and put the model in DDP mode as usual, without needing to check the rank:
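```python
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# map_location='cpu' stops every rank from loading onto the GPU the checkpoint was saved from
checkpoint = torch.load('/path/to/checkpoint.pth.tar', map_location='cpu')
model = ResNet50(...)
model = DDP(model, ...)
# The saved keys carry DDP's 'module.' prefix, so load into the wrapped model.
# load_state_dict() returns the missing/unexpected keys, not the model, so don't chain it.
model.load_state_dict(checkpoint['state_dict'])
```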
If you want to load it but not in DDP mode, it is a bit tricky: the DDP wrapper stores the network as its `module` attribute, so every saved key gets an extra `module.` prefix. As solved here, you have to do:
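```python
from collections import OrderedDict

import torch

checkpoint = torch.load('/path/to/checkpoint.pth.tar', map_location='cpu')
state_dict = checkpoint['state_dict']
new_state_dict = OrderedDict()
for k, v in state_dict.items():
    # remove the 'module.' prefix added by DataParallel/DistributedDataParallel
    name = k[len('module.'):] if k.startswith('module.') else k
    new_state_dict[name] = v
model = ResNet50(...)  # placeholder for the model class
model.load_state_dict(new_state_dict)
```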
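Recent PyTorch versions also ship a helper for this, `consume_prefix_in_state_dict_if_present` in `torch.nn.modules.utils`, which strips the prefix in place and also rewrites the keys of the state dict's `_metadata` attribute (the manual loop builds a fresh `OrderedDict`, so that attribute is not carried over). A minimal sketch, where the wrapper name is hypothetical:

```python
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present


def strip_ddp_prefix(state_dict):
    """Remove a leading 'module.' from every key, in place, and return the dict."""
    # Keys without the prefix are left untouched, so this is safe to call
    # on checkpoints saved with or without the DDP wrapper.
    consume_prefix_in_state_dict_if_present(state_dict, "module.")
    return state_dict
```

Usage with the names from the answer above: `model.load_state_dict(strip_ddp_prefix(checkpoint['state_dict']))`.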