[英]Is PyTorch DistributedDataParallel with different GPU speeds syncing weights?
[英]Is there a way to replace the 'allreduce_hook' used for DDP(DistributedDataParallel) in Pytorch?
我知道 Pytorch DDP 使用 'allreduce_hook' 作为默认通信挂钩。 有没有办法用“quantization_pertensor_hook”或“powerSGD_hook”替换这个默认挂钩。 有一个官方的Pytorch 文档介绍了 DDP 通信钩子,但我仍然对如何在实践中做到这一点感到困惑。
这就是我启动进程组并创建 DDP model 的方式。
import torch.distributed as dist
import torch.nn as nn
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[0])
有没有办法根据这段代码声明我想要的钩子?
这可以完成这项工作
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[0])
state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1, start_powerSGD_iter=10, min_compression_rate=0.5)
model.register_comm_hook(state, powerSGD.powerSGD_hook)
...
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.