
Using `dask.array.map_blocks()` to parallelize line fitting on a 3-D `dask.array`

I have a series of N images recorded at different times. I have stacked the images into a 3-D dask array and rechunked them along the time axis. I would now like to perform a linear fit at each pixel position across the image, but as I try to scale up I run into the following error when using `da.map_blocks`: `TypeError: expected 1D or 2D array for y`

I found one other related post, applying-a-function-along-an-axis-of-a-dask-array, but it didn't address the issue of setting the chunk size specifically. When using `da.apply_along_axis`, I found an issue similar to the one reported in dask-performance-apply-along-axis, wherein only one CPU seems to be utilized during the computation (even for chunked data).

MWE: Works properly

import dask.array as da
import numpy as np
import matplotlib.pyplot as plt
plt.style.use('ggplot')

def f(y, args, axis=None):
    return np.polyfit(args[0], y.squeeze(), args[1])[:, None, None]

deg = 1
nsamp=20*10*10
shape=(20,10,10)
chunk_size=(20,1,1)
a = da.linspace(1, nsamp, nsamp).reshape(shape)
chunked = a.rechunk(chunk_size)
times = da.linspace(1, shape[0], shape[0])
# Each block returns deg + 1 coefficients, so the output chunk is (deg + 1, 1, 1)
results = chunked.map_blocks(f, chunks=(deg + 1, 1, 1), args=[times, deg], dtype='float').compute()
m_fit = results[0]
b_fit = results[1]

# Plot a few fits to visually examine them
fig, ax = plt.subplots(nrows=1, ncols=1)
for (x,y) in zip([1,9], [1,9]):
    ax.scatter(times, chunked[:,x,y])
    ax.plot(times, np.polyval([m_fit[x, y], b_fit[x,y]], times))

The array, `chunked`, looks like this: [image not available]

The resulting plot looks like this: [image not available]

This is exactly what I would expect, so all is well. However, the issue arises whenever I try to use a chunk size larger than one.

MWE: Raises TypeError

nsamp=20*10*10
shape=(20,10,10)
chunk_size=(20,5,5) # Chunking the data now
a = da.linspace(1,nsamp, nsamp).reshape(shape)
chunked = a.rechunk(chunk_size)
times = da.linspace(1, shape[0], shape[0])
results = chunked.map_blocks(f, chunks=(20,1,1), args=[times, 1], dtype='float') # error

Does anyone have any ideas as to what is happening here?

It looks like your function may expect one-dimensional inputs. I wonder whether you can write a Python function that wraps your function and handles the unpacking and repacking of the one-dimensional inputs. If you can get that function to work on a single NumPy array of shape (20, 2, 2), for example, then you can probably use Dask to apply it across many similarly sized chunks.
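As a rough sketch of that suggestion (not necessarily what the answerer had in mind): `np.polyfit` accepts a 2-D `y` and fits each column independently, so one possible wrapper flattens the two spatial axes of a block, fits all pixels in one call, and reshapes the coefficients back. The name `fit_block` and its keyword arguments are hypothetical, and the example assumes `times` is a plain NumPy array and that the input is built with `da.from_array` rather than the question's `da.linspace(...).reshape(...)`.

```python
import numpy as np
import dask.array as da

def fit_block(block, times=None, deg=1):
    """Hypothetical wrapper: fit a degree-`deg` polynomial along axis 0
    for every pixel of a (T, ny, nx) block; returns (deg + 1, ny, nx)."""
    t, ny, nx = block.shape
    # Flatten the spatial axes so y is 2-D: one column per pixel.
    coeffs = np.polyfit(times, block.reshape(t, ny * nx), deg)
    # Restore the spatial layout of the coefficients.
    return coeffs.reshape(deg + 1, ny, nx)

deg = 1
shape = (20, 10, 10)
times = np.linspace(1, shape[0], shape[0])

# The original f fails because a (20, 5, 5) block is still 3-D after squeeze().
try:
    np.polyfit(times, np.zeros((20, 5, 5)), deg)
except TypeError as err:
    print(err)  # expected 1D or 2D array for y

# Same data as in the question, chunked as (20, 5, 5).
data = np.linspace(1, 20 * 10 * 10, 20 * 10 * 10).reshape(shape)
chunked = da.from_array(data, chunks=(20, 5, 5))

results = chunked.map_blocks(
    fit_block,
    times=times,
    deg=deg,
    chunks=(deg + 1, 5, 5),  # each output block holds deg + 1 coefficients
    dtype=float,
).compute()

m_fit, b_fit = results  # slope and intercept maps, each of shape (10, 10)
```

With this layout the time axis has to stay in a single chunk, as it does in the question, because each block must see the full time series for its pixels.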
