
How do I get Keras to train a model on a specific GPU?

There is a shared server with 2 GPUs at my institution. Suppose two team members each want to train a model at the same time. How can they get Keras to train each model on a specific GPU so as to avoid a resource conflict?

Ideally, Keras should figure out which GPU is currently busy training a model and then use the other GPU to train the other model. However, this doesn't seem to be the case. It seems that by default Keras only uses the first GPU (since the Volatile GPU-Util of the second GPU is always 0%).


Possibly a duplicate of my previous question.

It's a bit more complicated. Keras will allocate memory on both GPUs, although it will only use one GPU for computation by default. See keras.utils.multi_gpu_model for using several GPUs.

I found the solution: choose the GPU with the environment variable CUDA_VISIBLE_DEVICES.

You can set it manually, before importing keras or tensorflow, to choose your GPU:

import os

# Set exactly one of these before importing keras or tensorflow.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"     # first GPU
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"   # second GPU
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide all GPUs and run on the CPU

To automate this, I wrote a function that parses the output of nvidia-smi, detects which GPU is already in use, and sets the variable accordingly.
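The function itself was not posted, so the following is only a minimal sketch of how such a helper might look, not the author's actual code. It assumes nvidia-smi is on the PATH and uses its CSV query mode; the names parse_gpu_stats, pick_least_busy_gpu and select_gpu are hypothetical, and "least busy" is taken here to mean lowest memory use, then lowest utilization.

```python
import os
import subprocess

# nvidia-smi's CSV query mode prints one row per GPU, e.g. "0, 10989, 97".
QUERY = [
    "nvidia-smi",
    "--query-gpu=index,memory.used,utilization.gpu",
    "--format=csv,noheader,nounits",
]


def parse_gpu_stats(csv_text):
    """Parse CSV rows into (index, memory_used_mib, util_percent) tuples."""
    stats = []
    for line in csv_text.strip().splitlines():
        fields = [f.strip() for f in line.split(",")]
        if len(fields) != 3:
            continue  # skip rows that do not have the three requested columns
        try:
            stats.append(tuple(int(f) for f in fields))
        except ValueError:
            continue  # skip rows with non-numeric fields such as "[N/A]"
    return stats


def pick_least_busy_gpu(stats):
    """Return the index of the GPU with the least memory in use, or None."""
    if not stats:
        return None
    # Assumption: memory use is the primary signal, utilization breaks ties.
    index, _, _ = min(stats, key=lambda s: (s[1], s[2]))
    return index


def select_gpu():
    """Set CUDA_VISIBLE_DEVICES to the least busy GPU, or "-1" if none found."""
    try:
        out = subprocess.run(
            QUERY, capture_output=True, text=True, check=True
        ).stdout
    except (OSError, subprocess.CalledProcessError):
        out = ""  # nvidia-smi is missing or failed; fall back to the CPU
    choice = pick_least_busy_gpu(parse_gpu_stats(out))
    os.environ["CUDA_VISIBLE_DEVICES"] = "-1" if choice is None else str(choice)
    return choice
```

Call select_gpu() before importing keras or tensorflow. Note that this check is not atomic: if two people start training at the same instant, both may pick the same GPU.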

If you are using a training script, you can simply set the variable on the command line when invoking the script:

CUDA_VISIBLE_DEVICES=1 python train.py 
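If you would rather pass the GPU as an argument to the script, something like the sketch below could go at the top of train.py. The --gpu flag and the configure_gpu name are made up for this illustration; they are not part of Keras or TensorFlow. The only real mechanism involved is still CUDA_VISIBLE_DEVICES.

```python
import argparse
import os


def configure_gpu(argv=None):
    """Set CUDA_VISIBLE_DEVICES from a hypothetical --gpu flag, if given."""
    parser = argparse.ArgumentParser(description="Train on a chosen GPU")
    parser.add_argument(
        "--gpu",
        default=None,
        help='GPU index such as "0" or "1"; "-1" hides all GPUs',
    )
    # parse_known_args leaves any other script arguments untouched.
    args, _ = parser.parse_known_args(argv)
    if args.gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    # Return whatever is in effect, including a value set in the shell.
    return os.environ.get("CUDA_VISIBLE_DEVICES")


# In train.py, call configure_gpu() first and import keras only afterwards,
# so that the variable is already set when the GPU runtime initializes.
```

It would then be run as python train.py --gpu 1. Without the flag, a value already set in the shell is left as is.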

If you want to train models on cloud GPUs (e.g. GPU instances from AWS), try this library:

!pip install aibro==0.0.45 --extra-index-url https://test.pypi.org/simple

from aibro.train import fit
machine_id = 'g4dn.4xlarge'  # instance name on AWS
job_id, trained_model, history = fit(
    model=model,
    train_X=train_X,
    train_Y=train_Y,
    validation_data=(validation_X, validation_Y),
    machine_id=machine_id
)

Tutorial: https://colab.research.google.com/drive/19sXZ4kbic681zqEsrl_CZfB5cegUwuIB#scrollTo=ERqoHEaamR1Y
