
Why is the TPU not recognized on my Google Cloud TPU VM instance?

I have launched a Google Cloud TPU VM instance and installed the latest version of JAX, but JAX does not detect my TPU. I followed the instructions at https://cloud.google.com/tpu/docs/troubleshooting/trouble-jax and got the following:

>>> import jax
>>> jax.devices()
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
[CpuDevice(id=0)]
>>> TF_CPP_MIN_LOG_LEVEL=0
>>> jax.devices()
[CpuDevice(id=0)]

Every Google Search result I have found for this error suggests installing JAX with CUDA support, but shouldn't that be unnecessary on a TPU?
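A side note on the transcript: typing TF_CPP_MIN_LOG_LEVEL=0 at the >>> prompt only creates a Python variable. It does not set the environment variable that the warning refers to, so the second jax.devices() call prints nothing new. JAX also initializes its backends once per process, so the variable should be set in a fresh interpreter before JAX is imported. The script below is a minimal sketch of that, assuming Python 3.8+. describe_backend is a hypothetical helper name, and the script uses jax.default_backend() to report which platform JAX picked.

```python
import importlib.util
import os

# Typing TF_CPP_MIN_LOG_LEVEL=0 at the >>> prompt only binds a Python name.
# The native runtime reads the process *environment*, so set it there,
# before JAX is imported for the first time in this process.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"


def describe_backend() -> str:
    """Return the platform JAX selected (e.g. "tpu" or "cpu"), or a note if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return "jax is not installed"
    import jax  # imported only after the environment variable is set

    return jax.default_backend()


print(describe_backend())
```

Run it as a new script, for example python3 check_backend.py (a hypothetical file name). A REPL that has already imported jax will not pick up the change.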

I recommend upgrading JAX and installing it with the tpu extra, which also installs the TPU runtime library (libtpu):

pip3 install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
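After reinstalling, you can check in a fresh interpreter whether the upgrade took effect. The following is a minimal sketch, assuming Python 3.8+. It prints the installed jax and jaxlib versions using importlib.metadata, then lists the devices JAX reports with platform "tpu". The helper names installed_version and tpu_devices are hypothetical. If the list is empty, JAX is still falling back to the CPU.

```python
import importlib.metadata
import importlib.util
from typing import List, Optional


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a pip distribution, or None if it is absent."""
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return None


def tpu_devices() -> List[object]:
    """Return the devices JAX reports with platform "tpu" (empty if it fell back to CPU)."""
    if importlib.util.find_spec("jax") is None:
        return []
    import jax

    # jax.devices() lists the default backend's devices; each has a .platform string.
    return [d for d in jax.devices() if d.platform == "tpu"]


for dist in ("jax", "jaxlib"):
    print(f"{dist}: {installed_version(dist)}")
print("TPU devices:", tpu_devices())
```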

EDIT: This actually appears to be a bug. See:

https://github.com/google/jax/issues/13260
