繁体   English   中英

Xarray 数据集中的重叠块,用于 Kernel 次操作

[英]Overlapping chunks in Xarray dataset for Kernel operations

我尝试使用自定义过滤器在大型卫星图像上运行 9x9 像素 kernel。 一个卫星场景大约有 40 GB,为了将其放入我的 RAM,我使用xarray的选项将我的数据集与dask块。

我的过滤器包括检查 kernel 是否完整(即没有丢失图像边缘的数据)。 在那种情况下,返回 NaN 以防止潜在的偏差(而且我并不真正关心边缘)。 我现在意识到,这不仅在图像边缘引入了 NaN(预期行为),而且还沿着每个块的边缘引入了 NaN,因为块不重叠。 dask提供了创建具有重叠的块的选项,但是xarray中是否有任何类似的功能? 我发现了这个问题,但在这方面似乎没有任何进展。

一些示例代码(我的原始代码的简化版本):


import numpy as np
import numba
import math
import xarray as xr


@numba.jit("f4[:,:](f4[:,:],i4)", nopython = True)
def water_anomaly_filter(input_arr, window_size = 9):
    # check if window size is odd
    if window_size%2 == 0:
        raise ValueError("Window size must be odd!")
    
    # prepare an output array with NaNs and the same dtype as the input
    output_arr = np.zeros_like(input_arr)
    output_arr[:] = np.nan
    
    # calculate how many pixels in x and y direction around the center pixel
    # are in the kernel
    pix_dist = math.floor(window_size/2-0.5)
    
    # create a dummy weight matrix
    weights = np.ones((window_size, window_size))
    
    # get the shape of the input array
    xn,yn = input_arr.shape
    
    # iterate over the x axis
    for x in range(xn):
        # determine limits of the kernel in x direction
        xmin = max(0, x - pix_dist)
        xmax = min(xn, x + pix_dist+1)
        
        # iterate over the y axis
        for y in range(yn):
            # determine limits of the kernel in y direction
            ymin = max(0, y - pix_dist)
            ymax = min(yn, y + pix_dist+1)

            # extract data values inside the kernel
            kernel = input_arr[xmin:xmax, ymin:ymax]
            
            # if the kernel is complete (i.e. not at image edge...) and it
            # is not all NaN
            if kernel.shape == weights.shape and not np.isnan(kernel).all():
                # apply the filter. In this example simply keep the original
                # value
                output_arr[x,y] = input_arr[x,y]
                
    return output_arr

def run_water_anomaly_filter_xr(xds, var_prefix = "band", 
                                window_size = 9):
    
    variables = [x for x in list(xds.variables) if x.startswith(var_prefix)]
    
    for var in variables[:2]:
        xds[var].values = water_anomaly_filter(xds[var].values, 
                                               window_size = window_size)
    
    return xds

def create_test_nc():

    data = np.random.randn(1000, 1000).astype(np.float32)

    rows = np.arange(54, 55, 0.001)
    cols = np.arange(10, 11, 0.001)

    ds = xr.Dataset(
        data_vars=dict(
            band_1=(["x", "y"], data)
        ),
        coords=dict(
            lon=(["x"], rows),
            lat=(["y"], cols),
        ),
        attrs=dict(description="Testdata"),
    )

    ds.to_netcdf("test.nc")

if __name__ == "__main__":

    # if required, create test data
    create_test_nc()
    
    # import data
    with xr.open_dataset("test.nc",
                         chunks = {"x": 50, 
                                   "y": 50},
                         
                         ) as xds:   

        xds_2 = xr.map_blocks(run_water_anomaly_filter_xr, 
                              xds,
                              template = xds).compute()

        xds_2["band_1"][:200,:200].plot()

这会产生:在此处输入图像描述

您可以清楚地看到每个块边缘的 NaN 行和列。

我很高兴提出任何建议。 我很想在xarray中获得重叠的块(或任何其他解决方案),但我也对其他解决方案持开放态度。

你可以使用 Dask 的map_blocks如下:

arr = dask.array.map_overlap(
    water_anomaly_filter, xds.band_1.data, dtype='f4', depth=4, window_size=9
).compute()
da = xr.DataArray(arr, dims=xds.band_1.dims, coords=xds.band_1.coords)

请注意,您可能希望为您的特定应用调整depthwindow_size

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM