[英]JAX GPU memory usage even with CPU allocation
我正在研究 JAX 的 memory 分配以使我的代碼更快,我發現 GPU memory 使用 JAX,即使我將它設置為僅使用 CPU。 我的小代碼是:
import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'cpu')
x=jnp.zeros(10)
for i in range(10000000000):
1+1
for循環就是看這個程序有沒有使用GPU。
在此之后,我發現它總是使用 GPU 的 253MiB:
303028 哦 20 0 29.5g 377844 294168 R 100.0 0.1 0:08.21 python./test.py
和
實際上,PID 299133 和 299522 也使用 GPU memory,JAX 設置為使用 CPU。 因此,我不確定我的實際代碼比我的 c++ 代碼慢得多,但我如何設置它根本不使用 GPU?
如果您安裝了支持 gpu 的 jaxlib,JAX 將在導入時預先保留 GPU memory 的 90%,除非您通過適當的環境變量另行指定。 例子是:
XLA_PYTHON_CLIENT_PREALLOCATE=false
禁用預分配行為XLA_PYTHON_CLIENT_MEM_FRACTION=.XX
指定memory的小數部分預分配有關這方面的更多信息,請參見JAX:GPU Memory 分配。
如果你想完全避免 GPU memory 分配,一個選擇是重新安裝 CPU-only jaxlib:
pip install --force-reinstall "jaxlib[cpu]"
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.