
Why does jax.numpy.dot() run slower than numpy.dot() on CPU?

I want to use JAX to accelerate my NumPy code on CPU, and later on GPU. Here is my example code, run on my local computer (CPU only):

import jax.numpy as jnp
from jax import random, jit
import numpy as np
import time

size = 3000

key = random.PRNGKey(0)
x =  random.normal(key, (size,size), dtype=jnp.float64)

start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))

x2=np.random.normal((size,size))

start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))

I got a warning, and the timings are as follows:

/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: 
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s

Did I do anything wrong here? Should JAX also accelerate NumPy code on CPUs?

There are probably performance differences between JAX and NumPy, but in the original post the time difference mostly comes down to a mistake in the array creation. The array used by JAX has shape 3000x3000, whereas the array used by NumPy is a 1D array of length 2. The first argument to numpy.random.normal is loc (i.e., the mean of the Gaussian from which to sample), so the tuple (size, size) is interpreted as two means and one sample is drawn for each. The keyword argument size= should be used to indicate the shape of the array.

numpy.random.normal(loc=0.0, scale=1.0, size=None)
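A minimal check of the difference between the two calls. The names wrong and right exist only for this sketch, and the correct call uses a small 4x4 shape to keep it cheap:

```python
import numpy as np

size = 3000

# A positional tuple is taken as `loc`: two means, one sample drawn per mean.
wrong = np.random.normal((size, size))

# The keyword `size=` sets the output shape (kept small here to stay cheap).
right = np.random.normal(size=(4, 4))

print(wrong.shape)  # (2,)
print(right.shape)  # (4, 4)

# With the 1D array, np.dot computes a scalar inner product of two numbers,
# which is why the original NumPy timing was so small.
print(np.dot(wrong, wrong.T).shape)  # ()
```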

Once this change is made, the performance gap between JAX and NumPy is much smaller.

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size), dtype=jnp.float64)

start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))

x2 = np.random.normal(size=(size, size)).astype(np.float64)

start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))

The output of one run is

Time of jnp: 2.3315 s
Time of np: 2.8811 s

When measuring performance, one should collect multiple runs, because a function's run time is a distribution rather than a single value. This can be done with the Python standard library's timeit.timeit function or the %timeit magic in IPython and Jupyter Notebook.

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 1.73 s ± 383 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
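Outside IPython, a similar measurement can be sketched with the standard library's timeit.repeat. The helper name bench and the small matrix size are assumptions made for this illustration, and it is shown with NumPy only:

```python
import timeit
import numpy as np

def bench(fn, repeats=5):
    """Return a list of wall-clock times (seconds), one per single call of fn."""
    fn()  # warm-up call, not timed (for JAX this would absorb compilation)
    return timeit.repeat(fn, repeat=repeats, number=1)

size = 200  # small, assumed size so the sketch runs quickly
xnp = np.random.normal(size=(size, size))

times = bench(lambda: np.dot(xnp, xnp.T))
print("np: min {:.6f} s, mean {:.6f} s".format(min(times), sum(times) / len(times)))
```

For the JAX version, the timed callable would be lambda: jnp.dot(xjnp, xjnp.T).block_until_ready(), so that asynchronous dispatch does not end the measurement before the result is computed.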

It seems like there is an optimized dot operation for 32-bit floats in NumPy.
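One caveat worth checking when reproducing these numbers: by default JAX uses 32-bit precision, and a requested float64 dtype is truncated to float32 unless 64-bit mode is enabled. If that was the case in the runs above, both JAX timings would be float32 timings, which would be consistent with them being nearly identical. A minimal sketch of enabling 64-bit mode, assuming a JAX version that provides jax.config.update:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit mode; this should run at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 4), dtype=jnp.float64)
print(x.dtype)  # float64
```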
