
Why does jax.numpy.dot() run slower than numpy.dot() on CPU?

I want to use JAX to accelerate my numpy code on the CPU, and later on the GPU. Here is my example code, run on my local machine (CPU only):

import jax.numpy as jnp
from jax import random
import numpy as np
import time

size = 3000

key = random.PRNGKey(0)
x = random.normal(key, (size, size), dtype=jnp.float64)

start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))

x2 = np.random.normal((size, size))

start = time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))

I got a warning, and the timings are as follows:

/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: 
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time of jnp: 0.45157814025878906s
Time of np: 0.005244255065917969s

Did I do anything wrong here? Should JAX also accelerate numpy code on CPUs?

There are real performance differences between JAX and NumPy, but in the original post the timing gap mostly comes down to a mistake in the array creation: the array used by JAX has shape 3000x3000, whereas the array passed to NumPy is a 1D array of length 2. The first positional argument to numpy.random.normal is loc (i.e., the mean of the Gaussian to sample from), so np.random.normal((size, size)) draws one sample for each of the two means (3000, 3000). The keyword argument size= must be used to specify the shape of the output array.

numpy.random.normal(loc=0.0, scale=1.0, size=None)
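
To see the difference (a quick check, not from the original post):

import numpy as np

size = 3000

# Passing the tuple positionally sets loc, the per-sample mean: this
# draws one sample for each of the two means (3000, 3000).
wrong = np.random.normal((size, size))
print(wrong.shape)   # (2,)

# Passing size= as a keyword produces the intended 3000x3000 array.
right = np.random.normal(size=(size, size))
print(right.shape)   # (3000, 3000)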

Once this change is made, the performance of JAX and NumPy is much closer.

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
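# Note: JAX defaults to 32-bit precision. Unless x64 is enabled via
# jax.config.update("jax_enable_x64", True), this float64 request is
# truncated to float32 (JAX emits a warning when this happens).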
x = jax.random.normal(key, (size, size), dtype=jnp.float64)

start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))

x2 = np.random.normal(size=(size, size)).astype(np.float64)

start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))

The output of one run is

Time of jnp: 2.3315 s
Time of np: 2.8811 s

When measuring performance, collect multiple runs: a single call yields one sample from a spread of times, not a definitive value. This can be done with the standard library's timeit module (e.g. timeit.timeit or timeit.repeat) or the %timeit magic in IPython and Jupyter notebooks.

import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
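# As above: without jax.config.update("jax_enable_x64", True), this
# array will actually be float32, not float64.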
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 1.73 s ± 383 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
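
The same comparison can be scripted outside IPython with the standard library's timeit module; a minimal sketch, reusing the (now float32) xjnp and xnp arrays from above:

import timeit

# Each call runs the product once; repeat=5 gives a spread of timings.
# block_until_ready() forces JAX's asynchronous dispatch to finish, so
# the measured time includes the actual computation.
jnp_times = timeit.repeat(
    lambda: jnp.dot(xjnp, xjnp.T).block_until_ready(),
    repeat=5, number=1)
np_times = timeit.repeat(lambda: np.dot(xnp, xnp.T), repeat=5, number=1)

print("jnp best: {:0.4f} s".format(min(jnp_times)))
print("np  best: {:0.4f} s".format(min(np_times)))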

It seems that NumPy has an optimized dot operation for 32-bit floats, presumably in its underlying BLAS library.
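
Which BLAS library NumPy dispatches dot to can be checked from its build configuration, e.g.:

import numpy as np

# Prints the BLAS/LAPACK libraries NumPy was built against
# (e.g. OpenBLAS or MKL); these implement np.dot under the hood.
np.show_config()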
