
Performance drop when slicing jax.numpy arrays

I have come across some behaviour in JAX that I don't understand, while trying to do an SVD compression of large arrays. Here is the sample code:

from jax import jit
import jax.numpy as jnp
import jax.scipy as jsc

@jit
def jax_compress(L):
    U, S, _ = jsc.linalg.svd(
        L,
        full_matrices=False,
        lapack_driver='gesvd',
        check_finite=False,
        overwrite_a=True,
    )

    maxS = jnp.max(S)
    chi = jnp.sum(S / maxS > 1e-1)

    return chi, jnp.asarray(U)

JAX with jit gives an enormous performance increase over SciPy for this snippet, but ultimately I want to reduce the dimensionality of U, which I do by wrapping it in the function:

def jax_process(A):
    chi, U = jax_compress(A)
    return U[:, 0:chi]

This step is unbelievably costly in terms of computation time, more so than the SciPy equivalent, as can be seen in this comparison:

[Figure: benchmark comparing the JAX and SciPy timings]

sc_compress and sc_process are the SciPy equivalents of the JAX code above. As you can see, it costs almost nothing to slice the arrays in SciPy, but it is very expensive when applied to the output of a jitted function. Does anyone have some insight into this behaviour?
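For reference, here is a minimal sketch of how the two stages can be timed separately (the test array A is an assumption; block_until_ready() makes sure the asynchronously dispatched work has actually finished before the clock stops):

import time
import numpy as np

A = np.random.rand(2000, 2000)  # assumed input; any large 2-D array works

t0 = time.time()
chi, U = jax_compress(A)
U.block_until_ready()           # wait for the dispatched SVD to finish
print(f"compress: {time.time() - t0:.4f} s")

t0 = time.time()
V = U[:, 0:int(chi)]
V.block_until_ready()           # now only the slice itself is measured
print(f"slice:    {time.time() - t0:.4f} s")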

I did a similar comparison of slicing speeds between JAX and PyTorch. dynamic_slice is substantially faster than regular slicing, but still much slower than the equivalent in torch. As I'm new to JAX, I'm not sure of the reason, but it could have to do with copying vs. referencing, since JAX arrays are immutable (see the view-vs-copy check after the benchmarks below).

JAX (without @jit)

from jax import random
from jax.lax import dynamic_slice

key = random.PRNGKey(0)
j = random.normal(key, (32, 2, 1024, 1024, 3))
%timeit j[..., 100:600, 100:600, :].block_until_ready()
%timeit dynamic_slice(j, [0, 0, 100, 100, 0], [32, 2, 500, 500, 3]).block_until_ready()
2.78 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
993 µs ± 12.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

PyTorch

import torch

t = torch.randn((32, 2, 1024, 1024, 3)).cuda()

%%timeit
t[..., 100:600, 100:600, :]
torch.cuda.synchronize()
7.63 µs ± 22.7 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
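To illustrate the copy-vs-reference hunch: basic slicing in PyTorch returns a view of the same storage, no data is moved, whereas an immutable JAX array generally has to be materialised as a new buffer. A quick (hypothetical) check with a small CPU tensor:

import torch

t = torch.randn(4, 4)
s = t[1:3, 1:3]        # basic slicing returns a view; no data is copied
# Both views start at element (1, 1) of t, so the pointers match:
print(s.data_ptr() == t[1, 1:].data_ptr())  # True: same underlying memory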

I'm no Jax expert and I'm not sure how it works under the hood, but I ran that snippet and had a look.

I'm pretty sure that the JAX functions within jax_compress (or the effects of the jit decorator) are lazily evaluated, so that they perform the full computation only when you "look inside" the output matrices at the end of the calculation and actually ask for concrete numbers (much as Python generators do, and functional languages like Haskell).

I think the array slicing you do at the end is essentially a form of this "asking a concrete question" of your matrices.
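For what it's worth, the JAX docs describe this behaviour as asynchronous dispatch: an operation returns a future-like array almost immediately, and the wait happens when you ask for concrete values. A minimal standalone demonstration (shapes are arbitrary):

import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (4000, 4000))

y = jnp.dot(x, x)      # returns almost immediately: work is only dispatched
y.block_until_ready()  # blocks until the matmul has actually finished
print(y[0, 0])         # indexing would also force the wait, as above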

You can check this by timing the jax_compress function on its own, and then again while accessing an element of the result:

import time

ti = time.time()
X, U = jax_compress(A)
# very fast: nothing has forced the result yet
print(f"Compression takes {time.time() - ti} seconds when not peeking")

ti = time.time()
X, U = jax_compress(A)
# much slower: printing an element forces the computation to finish
print(U[0][0])
print(f"Compression takes {time.time() - ti} seconds when peeking")

One solution may be to use lax.dynamic_slice or lax.dynamic_update_slice, for which I believe there is a JAX implementation within jax.numpy.lax_numpy. However, depending on your hardware, my hunch is that you will not find much of a speedup, since the SciPy implementation of SVD is already a wrapper around highly optimised, compiled Fortran code anyway (for a single-CPU machine).
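As a sketch of what that might look like here (jax_process_dynamic is a hypothetical name, and jax_compress is assumed from the question; note that lax.dynamic_slice takes dynamic start indices but needs static slice sizes, so chi has to be pulled back to a concrete Python int outside of jit):

from jax import lax

def jax_process_dynamic(A):
    chi, U = jax_compress(A)
    # dynamic_slice requires concrete slice sizes, so force chi
    # back to a Python int (this also waits for the SVD to finish).
    chi = int(chi)
    return lax.dynamic_slice(U, (0, 0), (U.shape[0], chi))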
