Local devices vs. non-local devices in multi-GPU processing

I'm reading the JAX documentation for jax.local_devices, which says:

Like jax.devices(), but only returns devices local to a given process.

And the documentation for jax.devices() says:

Returns a list of all devices for a given backend.

I don't understand what exactly local and non-local devices are. Could you please elaborate on the difference between them?

This is discussed in the JAX documentation, in "Using JAX in multi-host and multi-process environments":

A process's local devices are those that it can directly address and launch computations on. For example, on a GPU cluster, each host can only launch computations on the directly attached GPUs. On a Cloud TPU pod, each host can only launch computations on the 8 TPU cores attached directly to that host (see the Cloud TPU System Architecture documentation for more details). You can see a process's local devices via jax.local_devices().
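
As a rough illustration of what "local" means here, the sketch below lists the devices the current process can address and checks that each one is attached to this process. It assumes only that jax is installed; the variable names (local, me, all_mine) are made up for the example. On a single machine without any distributed setup there is just one process, so the process index is 0.

```python
import jax

# Devices this process can directly launch computations on.
local = jax.local_devices()

# Index of the current process (0 when running as a single process).
me = jax.process_index()

for d in local:
    # Each device records its id, platform and the process it is attached to.
    print(d.id, d.platform, d.process_index)

# Every local device should belong to the current process.
all_mine = all(d.process_index == me for d in local)
print(all_mine)
```

On a CPU-only machine this typically reports a single CPU device; the exact list depends on the hardware and the installed backend.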

The global devices are the devices across all processes. A computation can span devices across processes and perform collective operations via the direct communication links between devices, as long as each process launches the computation on its local devices. You can see all available global devices via jax.devices(). A process's local devices are always a subset of the global devices.
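
The relationship between the two lists, and a collective that spans them, can be sketched as follows. This is an illustrative example rather than a recipe for a real cluster: it assumes a plain single-process run, where the local and global lists coincide. In a real multi-process job, each process would normally call jax.distributed.initialize() first and would then see more global devices than local ones. The names global_ids, local_ids, is_subset, x and total are made up for the example.

```python
import jax
import jax.numpy as jnp

# Compare device ids rather than device objects.
global_ids = {d.id for d in jax.devices()}
local_ids = {d.id for d in jax.local_devices()}

# Local devices are always a subset of the global devices.
is_subset = local_ids <= global_ids
print(jax.process_count(), len(local_ids), len(global_ids), is_subset)

# Each process supplies one input value per local device.
x = jnp.ones(jax.local_device_count())

# psum reduces over the named axis, which covers all global devices.
total = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)

# Each entry should equal the global device count.
print(total)
```

Each process passes in only the shards for its own local devices, but the sum returned on every device counts all global devices, which is the sense in which a computation "spans devices across processes".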

