簡體   English   中英

即使有 CPU 分配,JAX GPU memory 的使用情況

[英]JAX GPU memory usage even with CPU allocation

我正在研究 JAX 的 memory 分配以使我的代碼更快,我發現 GPU memory 使用 JAX,即使我將它設置為僅使用 CPU。 我的小代碼是:

import jax
import jax.numpy as jnp

jax.config.update('jax_platform_name', 'cpu')

x=jnp.zeros(10)

for i in range(10000000000):
    1+1

for循環就是看這個程序有沒有使用GPU。

在此之后,我發現它總是使用 GPU 的 253MiB:

303028 哦 20 0 29.5g 377844 294168 R 100.0 0.1 0:08.21 python./test.py

在此處輸入圖像描述

實際上,PID 299133 和 299522 也使用 GPU memory,JAX 設置為使用 CPU。 因此,我不確定我的實際代碼比我的 c++ 代碼慢得多,但我如何設置它根本不使用 GPU?

如果您安裝了支持 gpu 的 jaxlib,JAX 將在導入時預先保留 GPU memory 的 90%,除非您通過適當的環境變量另行指定。 例子是:

  • XLA_PYTHON_CLIENT_PREALLOCATE=false禁用預分配行為
  • XLA_PYTHON_CLIENT_MEM_FRACTION=.XX指定memory的小數部分預分配

有關這方面的更多信息,請參見JAX:GPU Memory 分配

如果你想完全避免 GPU memory 分配,一個選擇是重新安裝 CPU-only jaxlib:

pip install --force-reinstall "jaxlib[cpu]"

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM