I tried to improve the computation time of some of my code, so I used the `njit` decorator from the numba module. In this example:
```python
import numpy as np
from numba import njit, int32, float64
from numba.experimental import jitclass  # older Numba versions: from numba import jitclass
import matplotlib.pyplot as plt
import time

spec = [('V_init', float64),
        ('a', float64),
        ('b', float64),
        ('g', float64),
        ('dt', float64),
        ('NbODEs', int32),
        ('dydx', float64[:]),
        ('time', float64[:]),
        ('V', float64[:]),
        ('W', float64[:]),
        ('y', float64[:])]

@jitclass(spec)
class FHNfunc:
    def __init__(self):
        self.V_init = .04
        self.a = 0.25
        self.b = 0.001
        self.g = 0.003
        self.dt = .01
        self.NbODEs = 2
        self.dydx = np.zeros(self.NbODEs)
        self.y = np.zeros(self.NbODEs)

    def Eul(self):
        # One explicit Euler step: y <- y + dt * dy/dt
        self.deriv()
        self.y += self.dydx * self.dt

    def deriv(self):
        # Right-hand side of the two ODEs for the state y = [V, W]
        self.dydx[0] = self.V_init - self.y[0] * (self.a - self.y[0]) * (1 - self.y[0]) - self.y[1]
        self.dydx[1] = self.b * self.y[0] - self.g * self.y[1]

@njit(fastmath=True)
def solve1(FH1, FHEuler, tp):
    V = np.zeros(len(tp))
    W = np.zeros(len(tp))
    for idx, t in enumerate(tp):
        FHEuler()
        V[idx] = FH1.y[0]
        W[idx] = FH1.y[1]
    return V, W

if __name__ == "__main__":
    FH1 = FHNfunc()
    FHEuler = FH1.Eul  # bound method: this is the argument Numba cannot type
    dt = .01
    tp = np.linspace(0, 1000, num=int(1000 / dt))
    t0 = time.time()
    V1, W1 = solve1(FH1, FHEuler, tp)
    print(time.time() - t0)
    plt.figure()
    plt.plot(tp, V1)
    plt.plot(tp, W1)
    plt.show()
```
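For reference, the jitclass integrates the following two ODEs, read directly off `deriv` and `Eul`, with an explicit Euler step of size $\Delta t$ (`dt = 0.01`):

```latex
\frac{dV}{dt} = V_{\text{init}} - V\,(a - V)(1 - V) - W,
\qquad
\frac{dW}{dt} = b\,V - g\,W,
\qquad
y_{n+1} = y_n + \Delta t\, f(y_n), \quad y = (V, W)
```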
I would like to pass a reference to a class method, `FHEuler = FH1.Eul`, but it crashes with this error:

```
This error may have been caused by the following argument(s):
- argument 1: cannot determine Numba type of <class 'method'>
```

So is it possible to pass a reference to an njit function? Or is there a workaround?
Numba cannot handle the method as an argument. An alternative is to compile the function first, then define an inner jitted function that takes the remaining arguments and runs the compiled function inside it, and finally call that inner function and return its result. Try this:
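```python
def solve1(FH1, FHEuler, tp):
    # Compile the passed-in function outside the jitted code
    FHEuler_f = njit(FHEuler)

    @njit(fastmath=True)
    def inner(FH1_x, tp_x):
        V = np.zeros(len(tp_x))
        W = np.zeros(len(tp_x))
        for idx, t in enumerate(tp_x):
            FHEuler_f  # as posted, this only references FHEuler_f; it does not call it
            V[idx] = FH1_x.y[0]
            W[idx] = FH1_x.y[1]
        return V, W

    return inner(FH1, tp)
```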
Passing the function might not be necessary at all. Getting the method from the jitclass instance inside the jitted function seems to work:
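```python
@njit(fastmath=True)
def solve1(FH1, tp):
    # Take the bound method inside nopython code instead of passing it in
    FHEuler = FH1.Eul
    V = np.zeros(len(tp))
    W = np.zeros(len(tp))
    for idx, t in enumerate(tp):
        FHEuler()
        V[idx] = FH1.y[0]
        W[idx] = FH1.y[1]
    return V, W

# Called as: V1, W1 = solve1(FH1, tp)
```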