I want to use JAX to accelerate my numpy code on CPU, later on GPU. Here is my example code running on my local computer (only CPU):
import jax.numpy as jnp
from jax import random, jit
import numpy as np
import time
size = 3000
key = random.PRNGKey(0)
x = random.normal(key, (size,size), dtype=jnp.float64)
start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))
x2=np.random.normal((size,size))
start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))
I got a warning and the time costs are as follows:
/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130:
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time of jnp: 0.45157814025878906s
Time of np: 0.005244255065917969s
Did I do anything wrong here? Should JAX also accelerate numpy code on CPUs?
There are probably performance differences between JAX and NumPy, but in the original post the time difference mostly comes down to a mistake in the array creation. The array used by JAX has shape 3000x3000, whereas the array used by NumPy is a 1D array of length 2. The first argument to numpy.random.normal is loc (i.e., the mean of the Gaussian from which to sample). The keyword argument size= should be used to indicate the shape of the array.
numpy.random.normal(loc=0.0, scale=1.0, size=None)
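To make the shape mistake concrete, here is a small check using only NumPy. The names wrong and right are just for illustration and are not from the original post.

```python
import numpy as np

size = 3000

# The tuple is taken as `loc`, so NumPy draws one sample per mean value.
wrong = np.random.normal((size, size))
print(wrong.shape)  # (2,)

# Passing the tuple as `size=` gives the intended matrix.
right = np.random.normal(size=(size, size))
print(right.shape)  # (3000, 3000)

# The dot product of two length-2 vectors is a scalar, hence the tiny timing.
print(np.shape(np.dot(wrong, wrong.T)))  # ()
```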
Once this change is made, the gap between JAX and NumPy is much smaller.
import time
import jax
import jax.numpy as jnp
import numpy as np
size = 3000
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size), dtype=jnp.float64)
start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))
x2 = np.random.normal(size=(size, size)).astype(np.float64)
start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))
The output of one run is
Time of jnp: 2.3315 s
Time of np: 2.8811 s
When measuring performance, one should collect multiple runs, because a function's run time is a distribution rather than a single value. This can be done with the Python standard library function timeit.timeit or the %timeit magic in IPython and Jupyter Notebook.
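For the standard library route, a minimal sketch outside IPython might look like the following. It uses timeit.repeat (a sibling of timeit.timeit that returns one timing per repetition), only NumPy, and a smaller matrix so it finishes quickly. The helper name matmul and the size of 500 are my own choices for illustration.

```python
import timeit

import numpy as np

size = 500  # assumption: smaller than 3000 so the sketch runs quickly
x = np.random.normal(size=(size, size))


def matmul():
    # The operation being timed.
    return np.dot(x, x.T)


# Seven repetitions, each timing a single call.
times = timeit.repeat(matmul, repeat=7, number=1)
print("best: {:.4f} s, mean: {:.4f} s".format(min(times), sum(times) / len(times)))
```

For the JAX version, the timed call would need .block_until_ready() as in the question. It is probably also worth calling the function once before timing, since the first call can include compilation overhead.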
import time
import jax
import jax.numpy as jnp
import numpy as np
size = 3000
key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)
%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)
%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.dot(xnp, xnp.T)
# 1.73 s ± 383 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
It seems that NumPy's dot operation is faster for 32-bit floats than for 64-bit floats.
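One caveat on the JAX side: as far as I know, JAX uses 32-bit floats by default, and a request for dtype=jnp.float64 is truncated to float32 unless 64-bit mode is enabled. That would explain why the two JAX timings are nearly identical. A minimal sketch for checking this, assuming a JAX version that provides jax.config.update and using a small 4x4 array purely for illustration:

```python
import jax

# Enable 64-bit mode; this should run at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (4, 4), dtype=jnp.float64)
print(x.dtype)  # float64 once x64 mode is enabled
```

With 64-bit mode enabled, the float64 comparison against NumPy is like for like. I have not re-run the timings above in that mode.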