简体   繁体   English

Xarray 数据集中的重叠块,用于 Kernel 次操作

[英]Overlapping chunks in Xarray dataset for Kernel operations

I try to run a 9x9 pixel kernel across a large satellite image with a custom filter.我尝试使用自定义过滤器在大型卫星图像上运行 9x9 像素 kernel。 One satellite scene has ~ 40 GB and to fit it into my RAM, I'm using xarray s options to chunk my dataset with dask .一个卫星场景大约有 40 GB,为了将其放入我的 RAM,我使用xarray的选项将我的数据集与dask块。

My filter includes a check if the kernel is complete (ie not missing data at the edge of the image).我的过滤器包括检查 kernel 是否完整(即没有丢失图像边缘的数据)。 In that case a NaN is returned to prevent a potential bias (and I don't really care about the edges).在那种情况下,返回 NaN 以防止潜在的偏差(而且我并不真正关心边缘)。 I now realized, that this introduces not only NaNs at the edges of the image (expected behaviour), but also along the edges of each chunk, because the chunks don't overlap.我现在意识到,这不仅在图像边缘引入了 NaN(预期行为),而且还沿着每个块的边缘引入了 NaN,因为块不重叠。 dask provides options to create chunks with an overlap , but are there any comparable capabilities in xarray ? dask提供了创建具有重叠的块的选项,但是xarray中是否有任何类似的功能? I found this issue , but it doesn't seem like there has been any progress in this regard.我发现了这个问题,但在这方面似乎没有任何进展。

Some sample code (shortened version of my original code):一些示例代码(我的原始代码的简化版本):


import numpy as np
import numba
import math
import xarray as xr


@numba.jit("f4[:,:](f4[:,:],i4)", nopython = True)
def water_anomaly_filter(input_arr, window_size = 9):
    # check if window size is odd
    if window_size%2 == 0:
        raise ValueError("Window size must be odd!")
    
    # prepare an output array with NaNs and the same dtype as the input
    output_arr = np.zeros_like(input_arr)
    output_arr[:] = np.nan
    
    # calculate how many pixels in x and y direction around the center pixel
    # are in the kernel
    pix_dist = math.floor(window_size/2-0.5)
    
    # create a dummy weight matrix
    weights = np.ones((window_size, window_size))
    
    # get the shape of the input array
    xn,yn = input_arr.shape
    
    # iterate over the x axis
    for x in range(xn):
        # determine limits of the kernel in x direction
        xmin = max(0, x - pix_dist)
        xmax = min(xn, x + pix_dist+1)
        
        # iterate over the y axis
        for y in range(yn):
            # determine limits of the kernel in y direction
            ymin = max(0, y - pix_dist)
            ymax = min(yn, y + pix_dist+1)

            # extract data values inside the kernel
            kernel = input_arr[xmin:xmax, ymin:ymax]
            
            # if the kernel is complete (i.e. not at image edge...) and it
            # is not all NaN
            if kernel.shape == weights.shape and not np.isnan(kernel).all():
                # apply the filter. In this example simply keep the original
                # value
                output_arr[x,y] = input_arr[x,y]
                
    return output_arr

def run_water_anomaly_filter_xr(xds, var_prefix = "band", 
                                window_size = 9):
    
    variables = [x for x in list(xds.variables) if x.startswith(var_prefix)]
    
    for var in variables[:2]:
        xds[var].values = water_anomaly_filter(xds[var].values, 
                                               window_size = window_size)
    
    return xds

def create_test_nc():

    data = np.random.randn(1000, 1000).astype(np.float32)

    rows = np.arange(54, 55, 0.001)
    cols = np.arange(10, 11, 0.001)

    ds = xr.Dataset(
        data_vars=dict(
            band_1=(["x", "y"], data)
        ),
        coords=dict(
            lon=(["x"], rows),
            lat=(["y"], cols),
        ),
        attrs=dict(description="Testdata"),
    )

    ds.to_netcdf("test.nc")

if __name__ == "__main__":

    # if required, create test data
    create_test_nc()
    
    # import data
    with xr.open_dataset("test.nc",
                         chunks = {"x": 50, 
                                   "y": 50},
                         
                         ) as xds:   

        xds_2 = xr.map_blocks(run_water_anomaly_filter_xr, 
                              xds,
                              template = xds).compute()

        xds_2["band_1"][:200,:200].plot()

This yields: enter image description here这会产生:在此处输入图像描述

You can clearly see the rows and columns of NaNs along the edges of each chunk.您可以清楚地看到每个块边缘的 NaN 行和列。

I'm happy for any suggestions.我很高兴提出任何建议。 I would love to get the overlapping chunks (or any other solution) within xarray , but I'm also open for other solutions.我很想在xarray中获得重叠的块(或任何其他解决方案),但我也对其他解决方案持开放态度。

You can use Dask's map_blocks as follows:你可以使用 Dask 的map_blocks如下:

arr = dask.array.map_overlap(
    water_anomaly_filter, xds.band_1.data, dtype='f4', depth=4, window_size=9
).compute()
da = xr.DataArray(arr, dims=xds.band_1.dims, coords=xds.band_1.coords)

Note that you will likely want to tune depth and window_size for your specific application.请注意,您可能希望为您的特定应用调整depthwindow_size

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM