繁体   English   中英

为什么 jax.numpy.dot() 在 CPU 上运行速度比 numpy.dot() 慢?

[英]Why does jax.numpy.dot() run slower than numpy.dot() on CPU?

我想使用 JAX 在 CPU 上加速我的 numpy 代码,稍后在 GPU 上。这是我在本地计算机(仅 CPU)上运行的示例代码:

import jax.numpy as jnp
from jax import random, jix
import numpy as np
import time

size = 3000

key = random.PRNGKey(0)
x =  random.normal(key, (size,size), dtype=jnp.float64)

start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))

x2=np.random.normal((size,size))

start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))

我收到警告,时间成本如下:

/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: 
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s

我在这里做错了什么吗? JAX 是否也应该在 CPU 上加速 numpy 代码?

Jax 和 Numpy 之间可能存在性能差异,但在原始帖子中,时间差异主要归结为数组创建中的错误。 Jax 使用的数组具有 3000x3000 的形状,而 Numpy 使用的数组是长度为 2 的一维数组numpy.random.normal的第一个参数是loc (即,要从中采样的高斯均值)。 关键字参数size=应用于指示数组的形状。

numpy.random.normal(loc=0.0, scale=1.0, size=None)

进行此更改后,Jax 和 Numpy 之间的性能差异较小。

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size), dtype=jnp.float64)

start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))

x2 = np.random.normal(size=(size, size)).astype(np.float64)

start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))

一次运行的output是

Time of jnp: 2.3315 s
Time of np: 2.8811 s

在测量定时性能时,应该收集多次运行,因为函数的性能是时间的分布而不是单个值。 这可以通过 Python 标准库timeit.timeit function 或 IPython 和 Jupyter Notebook 中的%timeit魔法来完成。

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 1.73 s ± 383 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

似乎在 Numpy 中有针对 32 位浮点数的优化点操作。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM