![](/img/trans.png)
[英]Validate on entire validation set when using ddp backend with PyTorch Lightning
[英]What is the proper way to checkpoint during training when using distributed data parallel (DDP) in PyTorch?
我想要(正确和官方 - 无错误的方式)做:
为此,我的猜测如下:
DDP(mdl)
。 我假设检查点保存了一个ddp_mdl.module.state_dict()
。近似代码:
def save_ckpt(rank, ddp_model, path):
if rank == 0:
state = {'model': ddp_model.module.state_dict(),
'optimizer': optimizer.state_dict(),
}
torch.save(state, path)
def load_ckpt(path, distributed, map_location=map_location=torch.device('cpu')):
# loads to
checkpoint = torch.load(path, map_location=map_location)
model = Net(...)
optimizer = ...
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
if distributed:
model = DDP(model, device_ids=[gpu], find_unused_parameters=True)
return model
这个对吗?
我要问的原因之一是分布式代码可能会出现 go 微妙的错误。 我想确保这不会发生在我身上。 当然我想避免死锁,但如果它发生在我身上就会很明显(例如,如果所有进程以某种方式试图同时打开同一个 ckpt 文件,可能会发生这种情况。在那种情况下,我会以某种方式确保其中只有一个一次加载一个,或者 rank 0 仅加载它,然后将其发送到进程的 rest)。
我也在问,因为官方文档对我没有意义。 我将粘贴他们的代码和解释,因为链接有时会失效:
保存和加载检查点 在训练期间使用 torch.save 和 torch.load 来检查点模块并从检查点恢复是很常见的。 有关详细信息,请参阅保存和加载模型。 使用 DDP 时,一项优化是将 model 仅保存在一个进程中,然后将其加载到所有进程中,从而减少写入开销。 这是正确的,因为所有过程都从相同的参数开始,并且梯度在反向传递中是同步的,因此优化器应该保持将参数设置为相同的值。 如果您使用此优化,请确保在保存完成之前所有进程都不会开始加载。 此外,在加载模块时,您需要提供适当的 map_location 参数以防止进程进入其他设备。 如果缺少 map_location,torch.load 将首先将模块加载到 CPU,然后将每个参数复制到保存位置,这将导致同一台机器上的所有进程使用同一组设备。 有关更高级的故障恢复和弹性支持,请参阅 TorchElastic。
def demo_checkpoint(rank, world_size):
print(f"Running DDP checkpoint example on rank {rank}.")
setup(rank, world_size)
model = ToyModel().to(rank)
ddp_model = DDP(model, device_ids=[rank])
loss_fn = nn.MSELoss()
optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)
CHECKPOINT_PATH = tempfile.gettempdir() + "/model.checkpoint"
if rank == 0:
# All processes should see same parameters as they all start from same
# random parameters and gradients are synchronized in backward passes.
# Therefore, saving it in one process is sufficient.
torch.save(ddp_model.state_dict(), CHECKPOINT_PATH)
# Use a barrier() to make sure that process 1 loads the model after process
# 0 saves it.
dist.barrier()
# configure map_location properly
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
ddp_model.load_state_dict(
torch.load(CHECKPOINT_PATH, map_location=map_location))
optimizer.zero_grad()
outputs = ddp_model(torch.randn(20, 10))
labels = torch.randn(20, 5).to(rank)
loss_fn = nn.MSELoss()
loss_fn(outputs, labels).backward()
optimizer.step()
# Not necessary to use a dist.barrier() to guard the file deletion below
# as the AllReduce ops in the backward pass of DDP already served as
# a synchronization.
if rank == 0:
os.remove(CHECKPOINT_PATH)
cleanup()
有关的:
我正在查看官方ImageNet 示例,这是他们的做法。 首先,他们在DDP 模式下创建 model :
model = ResNet50(...)
model = DDP(model,...)
在保存检查点,他们检查它是否是主进程,然后保存state_dict
:
import torch.distributed as dist
if dist.get_rank() == 0: # check if main process, a simpler way compared to the link
torch.save({'state_dict': model.state_dict(), ...},
'/path/to/checkpoint.pth.tar')
在加载过程中,他们加载 model 并像往常一样将其置于 DDP 模式,无需检查排名:
checkpoint = torch.load('/path/to/checkpoint.pth.tar')
model = ResNet50(...).load_state_dict(checkpoint['state_dict'])
model = DDP(...)
如果你想加载它但不是在 DDP 模式下,这有点棘手,因为出于某种原因,他们用额外的后缀module
保存它。 如此处解决,您必须执行以下操作:
state_dict = torch.load(checkpoint['state_dict'])
from collections import OrderedDict
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove 'module.' of DataParallel/DistributedDataParallel
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.