繁体   English   中英

如何使 numba(nopython=true) 与元素数量未知的 1D numpy.ndarray 输入一起使用

[英]How to make numba(nopython=true) work with 1D numpy.ndarray input with unknown number of elements

我正在将一个(数学复杂/涉及但操作很少)自制经验分布 class 从 C++/MATLAB(我都有)移植到 Python。

该文件有大约 1100 行代码,包括注释和测试数据,包括

if __name__ == "__main__": 

在文件的底部。

第 83 行有 function 声明: def cdf(self, x):

哪个编译并运行良好,它只是非常慢,所以我想用@numba.jit(nopython=True)编译以使其运行得更快。

但是,编译在文件npts=len(x)的 function (仅前面的注释)第 85 行的最早行之一上终止。

消息以:

[1] During: typing of argument at
C:\Users\kdalbey\Canopy\scripts\empDist.py (85)
--%<-----------------------------------------------------------------

File "Canopy\scripts\empDist.py", line 85

This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class '__main__.empDist'>

现在我真的在文件顶部做了一个import numpy as np但是为了清楚下面的这条消息,我尝试用numpy替换np 但我可能错过了一些。

如果我使用npts=x.size ,我会收到相同的错误消息。

所以我尝试输入x为:

@numba.jit(nopython=True)
def cdf(self, x: numpy.ndarray(dtype=numpy.float64)):

我得到以下错误

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
C:\Users\kdalbey\Canopy\scripts\empDist.py in <module>()
     15 np.set_printoptions(precision=16)
     16 
---> 17 class empDist:
     18     def __init__(self, xdata):
     19         npts=len(xdata)
C:\Users\kdalbey\Canopy\scripts\empDist.py in empDist()
     81 
     82     @numba.jit(nopython=True)
---> 83     def cdf(self, x: np.ndarray(dtype=np.float64)):
     84         # compute the value of cdf at vector of points x
     85         npts = x.size
TypeError: Required argument 'shape' (pos 1) not found

但是我不知道1D numpy.ndarray有多少元素(随意)

我猜我也许可以做一个

@numba.jit(nopython=True)
def cdf(self, x: numpy.ndarray(shape=(), dtype=numpy.float64)):

并且它仅通过 go 回到该错误

[1] During: typing of argument at
C:\Users\kdalbey\Canopy\scripts\empDist.py (85)
--%<-----------------------------------------------------------------
File "Canopy\scripts\empDist.py", line 85
This error may have been caused by the following argument(s):
- argument 0: cannot determine Numba type of <class '__main__.empDist'>

如果我执行npts=int(x.size)npts=numpy.int32(x.size)也是同样的错误,所以我认为问题出在x上。

由于多个问题(从 numba 版本 0.46.0 开始),您的方法存在问题:

  • numpy.ndarray(shape=(), dtype=numpy.float64)真的试图创建一个 NumPy 数组。 将它用作类型提示并不重要。 它仍然被执行(并且失败)。
  • 您应该在jit中使用更合适的(对于 numba)签名,而不是类型提示。 甚至更好:完全省略签名,让 numba 弄清楚。 在大多数情况下,numba 更胜一筹,并且花费更少的精力(如果您不需要限制类型)。
  • 你不能在jit模式下 jit 方法。 更好的方法是制作 function 并从您的方法中调用它。

所以在你的情况下:

import numba as nb

@nb.njit
def _cdf(x):
    # do something with x

class empDist:
    def cdf(self, x):
        result = _cds(x)
        ...

您的示例可能更复杂,但这应该为您提供一个很好的起点。 如果您需要使用实例属性,则只需将它们传递给_cdf (如果 numba 支持它们)。


一般来说,尝试在所有东西上使用 numba 并不是一个好主意。 Numba 的 scope 非常有限,但在它适用的地方,它可以是惊人的。

在您的情况下,您说它很慢。 那么第一步应该是分析你的代码并找出它为什么慢以及在哪里。 然后尝试找出是否可以用更快的方法解决这个瓶颈。 通常问题不在于代码本身,而在于算法/方法。 检查它是否使用次优方法。 如果它不是一个数字繁重的部分,那么使用 numba 可能是有意义的 - 但请注意:通常您根本不需要 numba,因为只需优化 NumPy 部件即可获得足够的性能。

好的...问题是它是一个方法(成员函数),我是从 MrFuppes 那里得到的。 将它隔离在它自己的 function 中,该方法调用的方法效果很好(几乎没有对 function 进行修改,在 numba 之前工作)。

顺便说一句,我将尝试获得批准以发布/发布经验分发代码,但这还有一段路要走。 我也可能想学习 cython 并重新编码以提高 cython 的速度,在我的机器上编译需要 O(秒),因为这些操作在数学上很复杂/涉及但从失败计数的角度来看并没有很多。 与 sklearn.neighbors.kde 相比,我的经验分布要快得多(在 @numba.jit(nopython=True) 编译缓存之后/折扣之后)。 在 windows 上的树冠中运行(使用 numba 0.36.2,因此 np.interp 没有从 numba 中受益)构建这个经验分布需要 5.72e-5 秒,而拟合sklearn kde 需要 2.03e-4 秒,获得 463 点。 此外,它应该很好地扩展到非常多的点。 除了 O(n log(n)) 的快速排序和 O(n) 的插值之外,构造(以及存储对象所需的 memory)成本为 O(n^(1/3))(具有显着系数到 O(n^(1/3))。它具有 PDF、CDF 和逆 CDF 的“简单”分析公式,因此经验分布的评估速度也快了很多。它与 sklearn 具有可比/略好的准确性高斯的 KDE(使用带宽 = (maxx-minx)*0.015 我复制了带宽,所以其他人的代码可能比我的 sklearn kde 更好,显然 kde 的准确性很大程度上取决于带宽,我的经验分布在构建过程中不采用除数据以外的任何参数,它通过算法计算出它需要了解的关于数据的所有信息),并且对于具有有限尾部(例如均匀或指数)的东西具有显着更好的准确性。提高的准确性部分来自于它闷闷不乐与 sklearn kde 相比,振动较小。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM