How to speed up xarray calculation with numba?

I am trying to better understand numba decorators, especially guvectorize.

I tried to start here, in particular at step 15 at the very bottom.

I tried to modify that example to calculate wind speed.

Here is what I got:

import numpy as np
import xarray as xr
import datetime
import glob
import dask

import sys
import os
import tempfile

from numba import float64, guvectorize, vectorize, njit

import time as t

@guvectorize(
    "(float64, float64, float64)",
    "(), () -> ()",
    nopython=True,
)
def calcWindspeed_ufunc(u, v, out):
        out = np.sqrt( u**2 + v**2 )


def calcWindspeed(u, v):

    return xr.apply_ufunc(calcWindspeed_ufunc, u, v,
                         input_core_dims=[[],[]],
                         output_core_dims=[[]],
                         # vectorize=True,
                         dask="parallelized",
                         output_dtypes=[u.dtype])


def main():

    nlon = 120
    nlat = 100
    ntime = 3650
    lon = np.linspace(129.4, 153.75, nlon)
    lat = np.linspace(-43.75, -10.1, nlat)
    time = np.linspace(0, 365, ntime)

    #< Create random data
    u = 10 * np.random.rand(len(time), len(lat), len(lon))
    u = xr.Dataset({"u": (["time", "lat", "lon"], u)},coords={"time": time, "lon": lon, "lat": lat})
    u = u.chunk({'time':365})
    u = u['u']
    v = u.copy()


    start = t.time()
    ws_xr = np.sqrt( u**2 + v**2 ).load()
    end = t.time()
    print('It took xarray {} seconds!'.format(end-start))

    start = t.time()
    ws_ufunc = calcWindspeed(u, v).load()
    end = t.time()
    print('It took numba {} seconds!'.format(end-start))

    # Difference of the output
    print( (ws_xr-ws_ufunc).max() )


if __name__ == '__main__':
    import dask.distributed
    import sys

    # Get the number of CPUS in the job and start a dask.distributed cluster
    mem          = 190
    cores        = 4
    memory_limit = '{}gb'.format(int(max(mem/cores, 4)))
    client       = dask.distributed.Client(n_workers=cores, threads_per_worker=1, memory_limit=memory_limit, local_dir=tempfile.mkdtemp())


    #< Print client summary
    print('### Client summary')
    print(client)
    print('\n\n')

    #< Call the main function
    main()

    #< Close the client
    client.shutdown()

This technically works (it runs), but the output is wrong. The maximum difference between the two calculations should be close to 0, but in my case it is 14.

I don't understand what I am doing wrong.

Thank you for your help!

A couple of thoughts:

  • There's no need to use numba if the function just calls numpy. Numba speeds things up by compiling code, but the current kernel is a single numpy call, so there is hardly any code for it to compile.
  • If you're using this to run over multiple dimensions, you can do that with xr.apply_ufunc alone.
  • If you want others to engage with the example, could you slim it down to its minimum size? Currently there's dask, xarray, and numba. If you cut those out, does the diff still hold?

As a reference, here are some functions I wrote using xarray and numba: https://github.com/shoyer/numbagg/blob/master/numbagg/moving.py
