
DistributedDataParallel with GPU device IDs specified in PyTorch

I want to train my model with DistributedDataParallel on a single machine that has 8 GPUs, but I only want to use four specific GPUs, with device IDs 4, 5, 6, and 7.

How do I specify the GPU device IDs for DistributedDataParallel?

I think the world size will be 4 in this case, but what should the rank be?

You can set the environment variable CUDA_VISIBLE_DEVICES. PyTorch will read this variable and use only the GPUs listed there. You can either set it directly in your Python code, like this:

import os

# Make only physical GPUs 4-7 visible; inside this process they will
# be re-indexed as cuda:0 .. cuda:3.
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'

Take care to execute this assignment before you initialize torch (in particular CUDA) in any way; otherwise the setting will not take effect. The other option is to set the environment variable temporarily when starting your script from the shell:

CUDA_VISIBLE_DEVICES=4,5,6,7 python your_script.py
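
To tie this back to the question about world size and rank: once CUDA_VISIBLE_DEVICES=4,5,6,7 is set, the four physical GPUs are re-indexed inside the process as cuda:0 through cuda:3, so world_size is 4 and each process uses a rank of 0 to 3, which can also serve as its (remapped) device index. Below is a minimal sketch of that setup, assuming a single machine, the nccl backend, and torch.multiprocessing.spawn to launch one process per GPU; the function name main_worker and the MASTER_ADDR/MASTER_PORT values are illustrative choices, not something prescribed by the original answer:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# Mask the GPUs before any CUDA initialization: physical GPUs 4-7
# become cuda:0 .. cuda:3 for this process and its spawned children.
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6,7'


def main_worker(rank, world_size):
    # One process per visible GPU; rank doubles as the local device index.
    os.environ.setdefault('MASTER_ADDR', 'localhost')   # assumed single-node setup
    os.environ.setdefault('MASTER_PORT', '29500')        # arbitrary free port
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(10, 10).cuda(rank)           # toy model for illustration
    ddp_model = DDP(model, device_ids=[rank])

    # ... training loop using ddp_model goes here ...

    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = 4  # four visible GPUs after masking
    mp.spawn(main_worker, args=(world_size,), nprocs=world_size)

Note that device_ids=[rank] refers to the remapped indices 0-3, not the physical IDs 4-7; the masking done by CUDA_VISIBLE_DEVICES is what maps one onto the other.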
