
Is there a way to replace the 'allreduce_hook' used for DDP (DistributedDataParallel) in Pytorch?

I know that Pytorch DDP uses 'allreduce_hook' as the default communication hook. Is there a way to replace this default hook with 'quantization_pertensor_hook' or 'powerSGD_hook'? There is official Pytorch documentation introducing the DDP communication hooks, but I am still confused about how to do this in practice.

This is how I initiate the process group and create the DDP model:

import torch.distributed as dist
import torch.nn as nn

# set up the default process group and wrap the model for distributed training
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[0])

Is there any way to register the hook that I want based on this code?

This could do the job:


import torch.distributed as dist
import torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook as powerSGD
import torch.nn as nn

dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[0])

# register the PowerSGD gradient-compression hook in place of the default allreduce hook
state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1, start_powerSGD_iter=10, min_compression_rate=0.5)
model.register_comm_hook(state, powerSGD.powerSGD_hook)
...
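
The registration pattern is the same for the per-tensor quantization hook you mentioned. A minimal sketch, assuming the quantization_hooks module from torch.distributed.algorithms.ddp_comm_hooks, where the hook's state is simply the process group (passing None uses the default group):

import torch.distributed.algorithms.ddp_comm_hooks.quantization_hooks as quant_hooks

# state for the quantization hooks is the process group; None means the default group
model.register_comm_hook(None, quant_hooks.quantization_pertensor_hook)

Note that register_comm_hook can only be called once on a given DDP model, so pick a single hook before training starts.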
