[英]Speed up integration of sums in Python
我正在嘗試加速 Python 中的特定(數字)積分。 我已經在 Mathematica 中進行了評估,需要 14 秒。 在 python 中需要 15.6 分鍾!
我要評估的積分形式如下:
python代碼如下:
from mpmath import hermite
def light_nm( dipol, n, m, t):
mat_elem = light_amp(n)*light_amp_conj(m)*coef_ground( dipol, n,t)*np.conj(coef_ground( dipol, m,t)) + \
light_amp(n+1)*light_amp_conj(m+1)*coef_excit( dipol, n+1,t)*np.conj(coef_excit( dipol, m+1,t))
return mat_elem
def light_nm_dmu( dipol, n, m, t):
mat_elem = light_amp(n)*light_amp_conj(m)*(coef_ground_dmu( dipol, n,t)*conj(coef_ground( dipol, m,t)) + coef_ground( dipol, n,t)*conj(coef_ground_dmu( dipol, m,t)) )+ \
light_amp(n+1)*light_amp_conj(m+1)*(coef_excit_dmu( dipol, n+1,t)*np.conj(coef_excit( dipol, m+1,t)) + coef_excit( dipol, n+1,t)*conj(coef_excit_dmu( dipol, m+1,t)))
return mat_elem
def prob(dipol, t, x, thlo, cutoff, n, m):
temp = complex( light_nm(dipol, n, m, t)* cmath.exp(1j*thlo*(n-m)-x**2)*\
hermite(n,x)*hermite(m,x)/math.sqrt(2**(n+m)*math.factorial(m)*math.factorial(n)*math.pi))
return np.real(temp)
def derprob(dipol, t, x, thlo, cutoff, n, m):
temp = complex( light_nm_dmu(dipol, n, m, t)* cmath.exp(1j*thlo*(n-m)-x**2)*\
hermite(n,x)*hermite(m,x)/math.sqrt(2**(n+m)*math.factorial(m)*math.factorial(n)*math.pi))
if np.imag(temp)>10**(-6):
print(t)
return np.real(temp)
def integrand(dipol, t, thlo, cutoff,x):
return 1/np.sum(np.array([ prob(dipol,t,x,thlo,cutoff,n,m) for n,m in product(range(cutoff),range(cutoff))]))*\
np.sum(np.array([ derprob(dipol,t,x,thlo,cutoff,n,m) for n,m in product(range(cutoff),range(cutoff))]))**2
def cfi(dipol, t, thlo, cutoff, a):
global alpha
alpha = a
temp_func_real = lambda x: np.real(integrand(dipol,t, thlo, cutoff, x))
temp_real = integ.quad(temp_func_real, -8, 8)
return temp_real[0]
Hermite 函數是從 mpmath 庫中調用的。 有什么辦法可以讓這段代碼運行得更快嗎?
謝謝!
更新:我添加了整個代碼。 (我很抱歉延遲)“light_nm_dmu”功能類似於“light_nm”。 我嘗試了答案,但在 light_amp 函數中收到錯誤“TypeError:只有大小為 1 的數組可以轉換為 Python 標量”,因此我對 prob 和 derprob 進行了矢量化。
對於相同的評估,新時間為 886.7085871696472 = 14.8 分鍾 (cfi(0.1,1,0,40,1))
建議使用:
使用緩存加速大量數字的階乘計算,即math.factorial 是否被記住? (Domenico De Felice 修改答案)
更新代碼
# use cached factorial function
def prob(dipol, t, x, thlo, cutoff, n, m):
temp = complex( light_nm(dipol, n, m, t)* cmath.exp(1j*thlo*(n-m)-x**2)*\
hermite(n,x)*hermite(m,x)/math.sqrt(2**(n+m)*factorial(m)*factorial(n)*math.pi))
return np.real(temp)
# Vectorize computation
def integrand(dipol, t, thlo, cutoff,x):
xaxis = np.arange(0, cutoff)
yaxis = np.arange(0, cutoff)
return 1/np.sum(prob(dipol,t,x,thlo,cutoff,xaxis[:, None] , yaxis[None, :]))*\
np.sum(derprob(dipol,t,x,thlo,cutoff,xaxis[:, None] , yaxis[None, :]))**2
# unchanged
def cfi(dipol, t, thlo, cutoff, a):
global alpha
alpha = a
temp_func_real = lambda x: np.real(integrand(dipol,t, thlo, cutoff, x))
temp_real = integ.quad(temp_func_real, -8, 8)
return temp_real[0]
# Cached factorial
def factorial(num, fact_memory = {0: 1, 1: 1, 'max': 1}):
' Cached factorial since we're computing on lots of numbers '
# Factorial is defined only for non-negative numbers
assert num >= 0
if num <= fact_memory['max']:
return fact_memory[num]
for x in range(fact_memory['max']+1, num+1):
fact_memory[x] = fact_memory[x-1] * x
fact_memory['max'] = num
return fact_memory[num]
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.