
How to isolate a subset of an array in a JAX JIT-compilable way?

I have some datasets I want to analyze, where the data for each dataset is stored in two arrays: one for the x values and one for the y values. These arrays are generally very large (~100,000 data points or more). The data are run through several JIT-compatible functions which perform various calculations on the y values. These calculations scale poorly with dataset size, so I want to write a function that isolates a subset of the data before passing it to the rest of my functions (so that I can process the entire dataset as a series of subsets, which should save time). I have a working function for isolating a subset of the data, but it is not JIT-compatible, which means I have to un-jit the rest of my functions (since they're all called from a jitted master function).

I've written up a simple dummy script which demonstrates my problem:

import jax
import numpy as np
import matplotlib.pyplot as plt


# enable 64-bit floating point representation in JAX
jax.config.update("jax_enable_x64", True)


def isolate_subset(indices: np.ndarray) -> tuple:
    
    # compute difference between successive x values
    delta = x[1] - x[0]
    
    # get indices of start and stop of subset
    i_start = jax.numpy.min(jax.numpy.where(jax.numpy.isclose(x, x[indices[0]], atol=delta/2))[0])
    i_stop = jax.numpy.max(jax.numpy.where(jax.numpy.isclose(x, x[indices[1]-1], atol=delta/2))[0])
    
    # get subset of x within specified bounds
    x_subset = jax.numpy.array(x[i_start:i_stop])
    y_subset = jax.numpy.array(y[i_start:i_stop])
    
    # return subset of x within specified bounds
    return x_subset, y_subset


# @jax.jit
def master_function(indices: np.ndarray) -> tuple:
    
    # get subset of x and y within bounds indices[0] to indices[1]
    x_subset, y_subset = isolate_subset(indices)
    
    # call some other functions to do some calculations
    
    # return data subset
    return x_subset, y_subset


# create some arbitrary x and y values
x = np.linspace(0, 10, 1000)
y = np.sin(2*np.pi*x)

# break data down into uniform subsets
subset_indices = np.arange(0, len(x)+1, int(len(x)/5))

# for each subset
for i in range(0, len(subset_indices)-1):
    
    # get indices of subset
    i_sub = subset_indices[i:i+2]
    
    # analyse subset of data
    x_subset, y_subset = master_function(i_sub)
    
    # create figure
    fig, ax = plt.subplots(tight_layout=True, dpi=200)
    
    # plot original (x, y) and subset of (x, y)
    ax.plot(x, y, "k--", zorder=0, label="original")
    ax.plot(x_subset, y_subset, "r-", zorder=1, label="subset")
    
    # label plot
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.legend()

plt.show()

This script works when master_function() is not jitted, but when I try to JIT it I get the following error:

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)>
The error occurred while tracing the function master_function at untitled0.py:35 for jit. This concrete value was not available in Python because it depends on the value of the argument 'indices'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError

I thought if I used jax.numpy instead of standard Python/NumPy, I wouldn't have any issues jitting. I've been using Python for years, but I've only been using JAX for about a month or so now, and I'm still figuring out its nuances. How can I make isolate_subset() JIT-compatible?

To be compatible with JIT, arrays in JAX must have static shapes, i.e. the size of an array cannot depend on the values within another array. With this in mind, what you are asking is impossible: your procedure creates subset arrays whose size depends on the values in the input arrays, so it can never be done in a JIT-compatible way.
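A minimal sketch of this distinction, assuming a recent JAX release: slicing with traced `start`/`stop` bounds fails under `jax.jit` even when every array is a `jax.numpy` array, whereas `jax.lax.dynamic_slice` works because only the *position* of the window is traced while its *length* (`size`, marked static here) is fixed at trace time. The function names `variable_length_subset` and `fixed_length_subset` are illustrative, not part of JAX.

```python
from functools import partial

import jax
import jax.numpy as jnp


@jax.jit
def variable_length_subset(x, start, stop):
    # The output length (stop - start) depends on runtime values of the
    # traced arguments, so JAX cannot determine the output shape when tracing.
    return x[start:stop]


@partial(jax.jit, static_argnames="size")
def fixed_length_subset(x, start, size):
    # `start` is traced (it may change between calls without recompiling);
    # `size` is a static Python int, so the output shape is known when tracing.
    return jax.lax.dynamic_slice(x, (start,), (size,))


x = jnp.linspace(0.0, 10.0, 1000)

try:
    variable_length_subset(x, 200, 400)
    variable_length_failed = False
except Exception:  # the exact exception type depends on the JAX version
    variable_length_failed = True

# Same 200-point window as x[200:400], but JIT-compatible.
subset = fixed_length_subset(x, 200, size=200)
```

Note that `jax.lax.dynamic_slice` clamps the start index so the window stays in bounds: with the 1000-point `x` above, a start of 950 and a size of 200 returns `x[800:1000]` rather than raising an error. For the equally sized subsets in the question, the loop could call `fixed_length_subset(x, start, size=200)` for `start` in `range(0, 1000, 200)`, and all five calls reuse a single compilation.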

The typical strategy in this case is to find a way to pre-define the expected size of the subset arrays, and pad them with extra values if necessary.
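When subsets have different lengths, one way to apply this padding strategy is sketched below. The helper names `padded_subset` and `masked_mean`, the NaN fill value, and the `max_size` cap are all illustrative assumptions: every subset is returned as a window of a fixed `max_size`, together with a boolean mask marking which entries are real data, and downstream calculations use the mask to ignore the padding.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnames="max_size")
def padded_subset(y, start, length, max_size):
    # Assumes 0 <= start, start + length <= len(y), and length <= max_size.
    # Append max_size NaNs so a window starting anywhere in y fits without
    # dynamic_slice having to clamp (shift) the start index.
    padding = jnp.full((max_size,), jnp.nan, dtype=y.dtype)
    y_padded = jnp.concatenate([y, padding])
    window = jax.lax.dynamic_slice(y_padded, (start,), (max_size,))
    # True for entries belonging to the requested subset, False for padding.
    mask = jnp.arange(max_size) < length
    return jnp.where(mask, window, jnp.nan), mask


@jax.jit
def masked_mean(values, mask):
    # Example downstream calculation that ignores the padded entries.
    return jnp.sum(jnp.where(mask, values, 0.0)) / jnp.sum(mask)


x = jnp.linspace(0.0, 10.0, 1000)
y = jnp.sin(2 * jnp.pi * x)

# Two subsets of different lengths share one compiled output shape (256,).
vals_a, mask_a = padded_subset(y, 0, 200, max_size=256)
vals_b, mask_b = padded_subset(y, 900, 100, max_size=256)
mean_b = masked_mean(vals_b, mask_b)
```

Because `max_size` is static, each distinct value triggers a recompilation; using a single maximum for all subsets (e.g. the largest subset length) keeps it to one compilation. Calculations that cannot simply skip NaNs should apply the mask explicitly, as `masked_mean` does.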


 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM