Is it possible to pass a class method reference to an njit function?

I am trying to reduce the computation time of some of my code, so I am using the njit decorator from the numba module. Here is an example:

import numpy as np
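from numba import jit, njit
from numba.experimental import jitclass  # jitclass is no longer importable from top-level numba in newer releases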
from numba import int32, float64
import matplotlib.pyplot as plt
import time

spec = [('V_init' ,float64),
        ('a' ,float64),
        ('b' ,float64),
        ('g',float64),
        ('dt' ,float64),
        ('NbODEs',int32),
        ('dydx' ,float64[:]),
        ('time' ,float64[:]),
        ('V' ,float64[:]),
        ('W' ,float64[:]),
        ('y'    ,float64[:]) ]

@jitclass(spec, )
class FHNfunc:
    def __init__(self,):
        self.V_init = .04
        self.a= 0.25
        self.b=0.001
        self.g = 0.003
        self.dt = .01
        self.NbODEs = 2
        self.dydx    =np.zeros(self.NbODEs  )
        self.y    =np.zeros(self.NbODEs  )

    def Eul(self,):
        self.deriv()
        self.y += (self.dydx * self.dt)

    def deriv(self,):
        self.dydx[0]= self.V_init - self.y[0] *(self.a-(self.y[0]))*(1-(self.y[0]))-self.y[1]
        self.dydx[1]= self.b * self.y[0] - self.g * self.y[1]
        return



@njit(fastmath=True)
def solve1(FH1,FHEuler,tp):
    V = np.zeros(len(tp), )
    W = np.zeros(len(tp), )

    for idx, t in enumerate(tp):
        FHEuler()  # call the passed-in method; a bare name is a no-op
        V[idx] = FH1.y[0]
        W[idx] = FH1.y[1]
    return V,W


if __name__ == "__main__":

    FH1 = FHNfunc()
    FHEuler = FH1.Eul

    dt = .01
    tp = np.linspace(0, 1000, num = int((1000)/dt))

    t0 = time.time()
    [V1,W1] = solve1(FH1,FHEuler,tp)
    print(time.time()- t0)
    plt.figure()
    plt.plot(tp,V1)
    plt.plot(tp,W1)
    plt.show()

I would like to pass a reference to a class method, FHEuler = FH1.Eul, but the call fails with this error:

This error may have been caused by the following argument(s):
- argument 1: cannot determine Numba type of <class 'method'>

So, is it possible to pass a method reference to an njit function, or is there a workaround?

Numba cannot infer a type for a bound method passed as an argument. An alternative is to compile the function first, then wrap the loop in an inner function that takes the remaining arguments and uses the compiled function from the enclosing scope. Please try this:

def solve1(FH1,FHEuler,tp):
    FHEuler_f = njit(FHEuler)
    @njit(fastmath=True)
    def inner(FH1_x, tp_x):
        V = np.zeros(len(tp_x), )
        W = np.zeros(len(tp_x), )
        for idx, t in enumerate(tp_x):
            FHEuler_f
            V[idx] = FH1_x.y[0]
            W[idx] = FH1_x.y[1]
        return V,W
    return inner(FH1, tp)

Passing the function may not be necessary. This version appears to work:

@njit(fastmath=True)
def solve1(FH1,tp):
    FHEuler = FH1.Eul
    V = np.zeros(len(tp), )
    W = np.zeros(len(tp), )

    for idx, t in enumerate(tp):
        FHEuler()
        V[idx] = FH1.y[0]
        W[idx] = FH1.y[1]
    return V,W
