繁体   English   中英

我在 numba 中的 python 程序没有加速

[英]My python program in numba is not speeding up

我为计算磁化而编写的程序需要更多的计算时间。 所以我切换到numba 。但我看不到任何速度增加。 谁能帮帮我。 我正在尝试在 24 核处理器中运行此代码。

import time  
import datetime
import numpy as np from math 
import pi
import numba from numba 
import jit,njit,double,vectorize,float64,int64
import time 

 #%% parameters for the calculations
    mu0 = 4e-7 * pi
        h_planck=6.58212e-4# mev*ns
        mub=5.78e-2#meV/T
        g=2
        s=2
        T=2.0 #K
       dt=0.5e-5
       Kb=8.6e-2
       Kjt=1.5
       gamma =(g*mub)/h_planck #1/(T*ns) 
       alpha = 1
       mus=mub*g*s
eA=np.array([-np.sqrt(2.0/3.0),0.0,-np.sqrt(1.0/3.0)])
eB=np.array([-np.sqrt(1.0/6.0),-np.sqrt(1.0/2.0),np.sqrt(1.0/3.0)])
eC=np.array([-np.sqrt(1.0/6.0),np.sqrt(1.0/2.0),np.sqrt(1.0/3.0)])
@njit
def dot(S1,eA,eB,eC):
    result1=0.0
    result2=0.0
    result3=0.0
    for i in range(3):
        result1 += S1[i]*eA[i]
        result2 += S1[i]*eB[i]
        result3 += S1[i]*eC[i]

    return result1,result2,result3
@njit
def jahnteller1(S1):
    global Kjt
    M,N,O=dot(S1,eA,eB,eC)
    P,Q,R=M**5,N**5,O**5
    X=3.0*Kjt*((eA*P+eB*Q+eC*R))
    return X/mus
@njit
def thermal1():
    mu, sigma = 0, 1 # mean and standard deviation
    G = np.random.normal(mu, sigma, 3)
    Hth1=G*np.sqrt((2*alpha*Kb*T)/(gamma*mus*dt))
    return Hth1
#%% calculation of effective field
@njit
def h_eff(B,S1,eH):
    Heff1 = eH*B+jahnteller1(S1)+thermal1()
    return  Heff1
#%% evaluating cross products
@njit
def cross1(S1,heff1):
    result1=np.zeros(3)
    a1, a2, a3 = S1[0], S1[1], S1[2]
    b1, b2, b3 = heff1[0], heff1[1],heff1[2]
    result1[0] = a2 * b3 - a3 * b2
    result1[1] = a3 * b1 - a1 * b3
    result1[2] = a1 * b2 - a2 * b1
    return result1
@njit
def cross2(S1,X):
    result2=np.zeros(3)
    a1, a2, a3 = S1[0],S1[1],S1[2]
    c1, c2, c3 = X[0],X[1],X[2]
    result2[0] = a2 * c3 - a3 * c2
    result2[1] = a3 * c1 - a1 * c3
    result2[2] = a1 * c2 - a2 * c1
    return result2
#%% Main function to calculate the Spin S1 by calculating the effective field
 @njit
def llg(S1,dt, B,eH):
    global gamma,alpha
    N_init = int(5)
    for i in range(N_init):
        heff1 = h_eff(B,S1,eH)
        X=cross1(S1,heff1)
        Y=cross2(S1,X)
        dS1dt = - gamma/(1+alpha**2) * X \
           - alpha*gamma/(1+alpha**2) * Y
        S1 += dt * dS1dt
        normS1 = np.sqrt(S1[0]*S1[0]+S1[1]*S1[1]+S1[2]*S1[2])
        S1 = S1/normS1
    Savg=np.array([0.0,0.0,0.0])
    Navg=N_init*10
    for i in range(Navg):
        heff1 = h_eff(B,S1,eH)
        X=cross1(S1,heff1)
        Y=cross2(S1,X)
        dS1dt = - gamma/(1+alpha**2) * X \
           - alpha*gamma/(1+alpha**2) * Y
        S1 += dt * dS1dt
        normS1 = np.sqrt(S1[0]*S1[0]+S1[1]*S1[1]+S1[2]*S1[2])
        S1 = S1/normS1
        Savg=Savg+S1
    Savg=Savg/Navg
    return Savg  
#%% calculating dot product
@njit
def dott(S1,K):
    result=0.0
    for i in range(3):
        result += S1[i]*K[i]
    return result


 #%% initialising magn
        magn=np.zeros([25,3]) 
        Th=[]
        Ph=[]
        B=5.0
        theta=np.linspace(0.0,np.pi,5)
        phi=np.linspace(0.0,2*np.pi,5)
    for i in range(len(phi)):
        for j in range(len(theta)):
            M,N=phi[i],theta[j]
            Th.append(N)
            Ph.append(M)

#%% calling the main fuction
for i in range(25):
    magn[i][0]=Ph[i]
    magn[i][1]=Th[i]
    eH=np.array([np.sin(Th[i])*np.cos(Ph[i]),np.sin(Th[i])*np.sin(Ph[i]),np.cos(Th[i])])
    normH = np.sqrt(eH[0]*eH[0]+eH[1]*eH[1]+eH[2]*eH[2])
    eH=eH/normH
    S1=np.array([np.sin(Th[i])*np.cos(Ph[i]),np.sin(Th[i])*np.sin(Ph[i]),np.cos(Th[i])])
    S1=llg(S1,dt,B,eH)
    K=eH*B
    Z=dott(S1,K)
    E=-Z*g*mub*s
    magn[i][2]=E

#%% printing magn
print(magn)
%timeit magn

需要注意的几点:

  1. 您似乎没有为整个操作计时。 我不确定最后一个%timeit表达式给你什么
  2. 在我的机器上,运行您给出的代码大约需要 2.5 秒
  3. 请注意,第一次调用 Numba function 非常慢,因为编译器会将代码转换为 llvm 代码。 如果将 @njit 更改为 @njit(cache=True) 则结果将被缓存,并且以后的运行不会导致编译(直到您更改函数)。 当我在我的机器上执行此操作时,第一次运行仍然需要 2.5 秒,但第二次运行在 0.12 秒内完成。
  4. 这些都无法在纯 python 中运行此 function,仅需 0.06 秒。

为什么?

在您的代码中出现这种情况的最大原因似乎是您在循环中从 Python 调用了许多小函数。 调用 numba function 会产生开销(我认为这比调用普通 python function 的开销更糟,因为需要进行类型检查。) 因此,如果您的 jitted 函数很简单,那么使用它们的好处就可以忽略不计(或者更糟的是,您会因此受到惩罚)。 如果您可以更改您的代码,以便整个逻辑(即主循环)也在一个 numba function 内,它可能会比纯 python 更快。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM