
My Python program is not speeding up with Numba

My program for calculating magnetisation takes a long time to compute, so I switched to Numba, but I don't see any speed increase. Could anyone help me? I am trying to run this code on a 24-core processor.

import time
import datetime
import numpy as np
from math import pi
import numba
from numba import jit, njit, double, vectorize, float64, int64

#%% parameters for the calculations
mu0 = 4e-7 * pi
h_planck = 6.58212e-4  # meV*ns
mub = 5.78e-2  # meV/T
g = 2
s = 2
T = 2.0  # K
dt = 0.5e-5
Kb = 8.6e-2
Kjt = 1.5
gamma = (g*mub)/h_planck  # 1/(T*ns)
alpha = 1
mus = mub*g*s
eA=np.array([-np.sqrt(2.0/3.0),0.0,-np.sqrt(1.0/3.0)])
eB=np.array([-np.sqrt(1.0/6.0),-np.sqrt(1.0/2.0),np.sqrt(1.0/3.0)])
eC=np.array([-np.sqrt(1.0/6.0),np.sqrt(1.0/2.0),np.sqrt(1.0/3.0)])
@njit
def dot(S1,eA,eB,eC):
    result1=0.0
    result2=0.0
    result3=0.0
    for i in range(3):
        result1 += S1[i]*eA[i]
        result2 += S1[i]*eB[i]
        result3 += S1[i]*eC[i]

    return result1,result2,result3
@njit
def jahnteller1(S1):
    global Kjt
    M,N,O=dot(S1,eA,eB,eC)
    P,Q,R=M**5,N**5,O**5
    X=3.0*Kjt*((eA*P+eB*Q+eC*R))
    return X/mus
@njit
def thermal1():
    mu, sigma = 0, 1 # mean and standard deviation
    G = np.random.normal(mu, sigma, 3)
    Hth1=G*np.sqrt((2*alpha*Kb*T)/(gamma*mus*dt))
    return Hth1
#%% calculation of effective field
@njit
def h_eff(B,S1,eH):
    Heff1 = eH*B+jahnteller1(S1)+thermal1()
    return  Heff1
#%% evaluating cross products
@njit
def cross1(S1,heff1):
    result1=np.zeros(3)
    a1, a2, a3 = S1[0], S1[1], S1[2]
    b1, b2, b3 = heff1[0], heff1[1],heff1[2]
    result1[0] = a2 * b3 - a3 * b2
    result1[1] = a3 * b1 - a1 * b3
    result1[2] = a1 * b2 - a2 * b1
    return result1
@njit
def cross2(S1,X):
    result2=np.zeros(3)
    a1, a2, a3 = S1[0],S1[1],S1[2]
    c1, c2, c3 = X[0],X[1],X[2]
    result2[0] = a2 * c3 - a3 * c2
    result2[1] = a3 * c1 - a1 * c3
    result2[2] = a1 * c2 - a2 * c1
    return result2
#%% Main function to calculate the Spin S1 by calculating the effective field
@njit
def llg(S1,dt, B,eH):
    global gamma,alpha
    N_init = int(5)
    for i in range(N_init):
        heff1 = h_eff(B,S1,eH)
        X=cross1(S1,heff1)
        Y=cross2(S1,X)
        dS1dt = - gamma/(1+alpha**2) * X \
           - alpha*gamma/(1+alpha**2) * Y
        S1 += dt * dS1dt
        normS1 = np.sqrt(S1[0]*S1[0]+S1[1]*S1[1]+S1[2]*S1[2])
        S1 = S1/normS1
    Savg=np.array([0.0,0.0,0.0])
    Navg=N_init*10
    for i in range(Navg):
        heff1 = h_eff(B,S1,eH)
        X=cross1(S1,heff1)
        Y=cross2(S1,X)
        dS1dt = - gamma/(1+alpha**2) * X \
           - alpha*gamma/(1+alpha**2) * Y
        S1 += dt * dS1dt
        normS1 = np.sqrt(S1[0]*S1[0]+S1[1]*S1[1]+S1[2]*S1[2])
        S1 = S1/normS1
        Savg=Savg+S1
    Savg=Savg/Navg
    return Savg  
#%% calculating dot product
@njit
def dott(S1,K):
    result=0.0
    for i in range(3):
        result += S1[i]*K[i]
    return result


#%% initialising magn
magn = np.zeros([25, 3])
Th = []
Ph = []
B = 5.0
theta = np.linspace(0.0, np.pi, 5)
phi = np.linspace(0.0, 2*np.pi, 5)
for i in range(len(phi)):
    for j in range(len(theta)):
        M, N = phi[i], theta[j]
        Th.append(N)
        Ph.append(M)

#%% calling the main function
for i in range(25):
    magn[i][0]=Ph[i]
    magn[i][1]=Th[i]
    eH=np.array([np.sin(Th[i])*np.cos(Ph[i]),np.sin(Th[i])*np.sin(Ph[i]),np.cos(Th[i])])
    normH = np.sqrt(eH[0]*eH[0]+eH[1]*eH[1]+eH[2]*eH[2])
    eH=eH/normH
    S1=np.array([np.sin(Th[i])*np.cos(Ph[i]),np.sin(Th[i])*np.sin(Ph[i]),np.cos(Th[i])])
    S1=llg(S1,dt,B,eH)
    K=eH*B
    Z=dott(S1,K)
    E=-Z*g*mub*s
    magn[i][2]=E

#%% printing magn
print(magn)
%timeit magn

A few things to note:

  1. You don't seem to be timing the whole operation; the final %timeit magn only times evaluating the already-computed array magn, not the computation that produced it.
  2. On my machine, the code as you have given it runs in about 2.5 seconds.
  3. Note that the first call to a Numba function is very slow because the compiler is translating the code to LLVM. If you change @njit to @njit(cache=True), the compiled result is cached and future runs skip compilation (until you change the function). When I do this on my machine, the first run still takes 2.5 seconds but the second run finishes in 0.12 seconds (a minimal timing sketch follows this list).
  4. None of this holds a candle to running the same logic in pure Python, which only takes 0.06 seconds.
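
For reference, here is a minimal sketch of how you could time the first and second call of a cached Numba function; kernel here is just a hypothetical toy stand-in, not your llg:

import time
import numpy as np
from numba import njit

# Hypothetical toy kernel; cache=True writes the compiled machine code to disk,
# so later runs (and later interpreter sessions) can skip compilation.
@njit(cache=True)
def kernel(x):
    total = 0.0
    for i in range(x.shape[0]):
        total += x[i] * x[i]
    return total

x = np.random.rand(1_000_000)

t0 = time.perf_counter()
kernel(x)  # first call: compilation (or cache load) plus execution
t1 = time.perf_counter()
kernel(x)  # second call: compiled code only
t2 = time.perf_counter()

print("first call :", t1 - t0, "s")
print("second call:", t2 - t1, "s")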

Why?

The biggest reason in your code seems to be that you have a number of small functions which you call from the Python interpreter inside a loop. Calling a Numba function from Python incurs overhead (which I think is worse than calling a normal Python function, since the arguments have to be type-checked and unboxed). So if your jitted functions do very little work, the benefit of compiling them becomes negligible, or you even pay a penalty for it. If you can restructure your code so that the entire logic (i.e. the main loop over the 25 field directions) is also inside a Numba function, it may become faster than pure Python.
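
To illustrate that last point, here is a rough, untested sketch of moving the outer loop into a single jitted driver. sweep is just a name I made up; it assumes Th and Ph have been converted to NumPy arrays first and reuses your llg and dott unchanged:

@njit(cache=True)
def sweep(Th, Ph, B, dt):
    # One jitted driver: llg and dott are now called from compiled code,
    # so the per-call overhead of crossing the Python/Numba boundary disappears.
    n = Th.shape[0]
    magn = np.zeros((n, 3))
    for i in range(n):
        eH = np.array([np.sin(Th[i])*np.cos(Ph[i]),
                       np.sin(Th[i])*np.sin(Ph[i]),
                       np.cos(Th[i])])
        eH = eH / np.sqrt(eH[0]**2 + eH[1]**2 + eH[2]**2)
        S1 = eH.copy()  # same initial spin direction as in your loop
        S1 = llg(S1, dt, B, eH)
        magn[i, 0] = Ph[i]
        magn[i, 1] = Th[i]
        magn[i, 2] = -dott(S1, eH*B) * g * mub * s
    return magn

# magn = sweep(np.array(Th), np.array(Ph), B, dt)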
