简体   繁体   English

使用 scipy curve_fit 与 dask/xarray

[英]using scipy curve_fit with dask/xarray

I'm trying to use scipy.optimize.curve_fit on a large latitude/longitude/time xarray using dask.distributed as computing backend.我正在尝试使用 dask.distributed 作为计算后端在大纬度/经度/时间 xarray 上使用 scipy.optimize.curve_fit。

The idea is to run an individual data fitting for every (latitude, longitude) using the time series.这个想法是使用时间序列为每个(纬度,经度)运行单独的数据。

All of this runs fine outside xarray/dask.所有这些在 xarray/dask 之外运行良好。 I tested it using the time series of a single location passed as a pandas dataframe.我使用作为 pandas dataframe 传递的单个位置的时间序列对其进行了测试。 However, if I try to run the same process on the same (latitude, longitude) directly on the xarray, the curve_fit operation returns the initial parameters.但是,如果我尝试直接在 xarray 上的相同(纬度、经度)上运行相同的过程,curve_fit 操作将返回初始参数。

I am performing this operation using xr.apply_ufunc like so (here I'm providing only the code that is strictly relevant to the problem):我正在使用xr.apply_ufunc执行此操作(这里我只提供与问题严格相关的代码):

    # function to perform the fit
    def _fit_rti_curve(data, data_rti, fit, loc=False):
        fit_func, linearize, find_init_params = _get_fit_functions(fit)
        # remove nans
        x, y = _filter_nodata(data_rti, data)
        # remove outliers
        x, y = _filter_for_outliers(x, y, linearize=linearize)

        # find a first guess for maximum achieveable value
        yscale = np.max(y) * 1.05
        # find a first guess for the other parameters
        # here loc can be manually passed if you have a good estimation
        init_parms = find_init_params(x, y, yscale, loc=loc, linearize=linearize)
        # fit the curve and return parameters
        parms = curve_fit(fit_func, x, y, p0=init_parms, maxfev=10000)
        parms = parms[0]
        return parms

    # shell around _fit_rti_curve
    def find_rti_func_parms(data, rti, fit):
        # sort and fit highest n values
        top_data = np.sort(data)
        top_data = top_data[-len(rti):]

        # convert to float64 if needed
        top_data = top_data.astype(np.float64)
        rti = rti.astype(np.float64)

        # run the fit
        parms = _fit_rti_curve(top_data, rti, fit, loc=0) #TODO maybe add function to allow a free loc
        return parms


    # call for the apply_ufunc
    # `fit` is a string that defines the distribution type
    # `rti` is an array for the x values
    parms_data = xr.apply_ufunc(
        find_rti_func_parms,
        xr_obj,
        input_core_dims=[['time']],
        output_core_dims=[[fit + ' parameters']],
        output_sizes = {fit + ' parameters': len(signature(fit_func).parameters) - 1},
        vectorize=True,
        kwargs={'rti':return_time_interval, 'fit':fit},
        dask='parallelized',
        output_dtypes=['float64']
    )

My guess would be that is a problem related to threading, or at least some shared memory space that is not properly passed between workers and scheduler.我的猜测是,这是与线程相关的问题,或者至少是一些共享的 memory 空间在工作程序和调度程序之间没有正确传递。 However, I am just not knowledgeable enough to test this within dask.但是,我只是没有足够的知识来测试这个。

Any idea on this problem?对这个问题有任何想法吗?

This previous answer might be helpful?这个先前的答案可能会有所帮助? It's using numpy.polyfit but I think the general approach should be similar.它使用numpy.polyfit但我认为一般方法应该相似。

Applying numpy.polyfit to xarray Dataset 将 numpy.polyfit 应用于 xarray 数据集

Also, I haven't tried it but xr.polyfit() just got merged recently.另外,我还没有尝试过,但xr.polyfit()最近刚刚合并。 Could also be something to look into.也可能是要研究的东西。 http://xarray.pydata.org/en/stable/generated/xarray.DataArray.polyfit.html#xarray.DataArray.polyfit http://xarray.pydata.org/en/stable/generated/xarray.DataArray.polyfit.html#xarray.DataArray.polyfit

You should have a look at this issue https://github.com/pydata/xarray/issues/4300 I had the same problem and I solved using apply_ufunc.你应该看看这个问题https://github.com/pydata/xarray/issues/4300我有同样的问题,我用 apply_ufunc 解决了。 It is not optimized, since it has to perform rechunking operations, but it works!它没有优化,因为它必须执行重新分块操作,但它可以工作! I've created a GitHub Gist for it https://gist.github.com/clausmichele/8350e1f7f15e6828f29579914276de71我为它创建了一个 GitHub Gist https://gist.github.com/clausmichele/8276871526

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM