
How to specify or set a variable to a GPU device

I'm new to JAX and want to work with multiple GPUs. So far, two GPUs (0 and 1) are visible to JAX:

import os

# Set this before JAX initializes its GPU backend
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import jax

print(jax.local_devices())
# [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]

When I create an array with jax.numpy, it is always placed on GPU 0, which I assume is the default device:

nmp = jax.numpy.ones(4)
print(nmp.device())  # gpu:0
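The assumption about the default device is correct: unless configured otherwise, JAX places new arrays on jax.devices()[0]. As a minimal sketch (assuming JAX 0.4 or later, where arrays are jax.Array objects with a devices() method), the jax.default_device context manager changes that default for a block of code, so arrays are created directly on another device rather than copied there afterwards. The fallback to the only device when fewer than two exist is an assumption added so the sketch also runs on a CPU-only machine.

```python
import jax
import jax.numpy as jnp

devs = jax.devices()
# Assumption: fall back to the only device so the sketch also runs without two GPUs.
target = devs[1] if len(devs) > 1 else devs[0]

a = jnp.ones(4)  # created on the default device, jax.devices()[0]
print(a.devices())

# Arrays created inside this block are placed on `target` instead of the default.
with jax.default_device(target):
    b = jnp.ones(4)
print(b.devices())
```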

How can I move nmp so that it is stored on gpu:1, the other GPU?

Use jax.device_put(). It copies an array to the device you pass it and returns a new array committed to that device:

import os

# Set this before JAX initializes its GPU backend
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import jax

devices = jax.local_devices()
print(devices)  # [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]

nmp = jax.numpy.ones(4)
print(nmp.device())  # gpu:0 (the default device)

# Copy the array to the second GPU; device_put returns a new array
nmp = jax.device_put(nmp, devices[1])
print(nmp.device())  # gpu:1
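A slightly fuller sketch of the same jax.device_put approach, again assuming JAX 0.4 or later. It inspects placement with Array.devices() (which returns a set of devices) rather than .device(), since newer JAX releases turned .device into a property. It also shows that the original array stays where it was, that results of operations on a committed array stay on its device, and that jax.device_put accepts a whole pytree. The dict keys "w" and "b" are hypothetical, and the single-device fallback is an assumption so the sketch runs without two GPUs.

```python
import jax
import jax.numpy as jnp

devs = jax.devices()
# Assumption: fall back to the only device so the sketch also runs without two GPUs.
target = devs[1] if len(devs) > 1 else devs[0]

x = jnp.ones(4)
# device_put returns a new array committed to `target`; `x` itself is not moved.
y = jax.device_put(x, target)
print(x.devices(), y.devices())

# Operations on committed arrays run on their device, so `z` also lives on `target`.
z = y * 2
print(z.devices())

# device_put also accepts a pytree, e.g. a dict of (hypothetical) parameters.
params = jax.device_put({"w": jnp.zeros(3), "b": jnp.ones(1)}, target)
print({k: v.devices() for k, v in params.items()})
```

Creating arrays directly on the target device (for example with jax.default_device) avoids the extra copy, while device_put is the tool for moving data that already exists.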
