簡體   English   中英

RK4 用 numba 加速

[英]Speeding up RK4 with numba

我想用 numba 實現 RK4 以加快運算速度。我是 numba 的初學者。為什麼 numba 無法理解(編譯)我的代碼?

簡化後的代碼如下:

在 swing.py 中:

@numba.jit(nopython=True)
def RK4(func, t_end, X0, dt):
    """Integrate ``dX/dt = func(t, X)`` with the classical 4th-order Runge-Kutta method.

    Parameters
    ----------
    func : numba-jitted function
        Right-hand side ``func(t, X)``. In nopython mode it must itself be
        compiled with numba. It must also return a 1-D NumPy float array
        with the same length as ``X``. The steps below compute ``hdt * k1``,
        and numba has no ``float64 * list`` implementation, so a function
        that returns a Python list fails to type-check.
    t_end : float
        End time (exclusive, as with ``np.arange``).
    X0 : 1-D array
        Initial state. Its length sets the number of state variables.
    dt : float
        Step size.

    Returns
    -------
    X : 2-D float64 array, shape ``(len(t), len(X0))``
        ``X[i]`` is the state at time ``t[i]``, where
        ``t = np.arange(0, t_end, dt)``. Because ``dt`` is a float, rounding
        can make ``t`` one element longer or shorter than ``t_end / dt``
        suggests.
    """
    t = np.arange(0, t_end, dt, dtype=np.float64)
    X = np.zeros((t.shape[0], X0.shape[0]))
    if t.shape[0] == 0:
        # No time points (e.g. t_end <= 0), so there is no row to store X0 in.
        # Writing X[0] anyway would be out of bounds, and numba does not
        # bounds-check indexing by default.
        return X
    X[0] = X0
    hdt = dt * .5
    for i in range(t.shape[0] - 1):
        ti = t[i]
        xi = X[i]  # view of row i; only row i + 1 is written in this iteration
        k1 = func(ti, xi)
        k2 = func(ti + hdt, xi + hdt * k1)
        k3 = func(ti + hdt, xi + hdt * k2)
        k4 = func(ti + dt, xi + dt * k3)
        X[i + 1] = xi + dt / 6. * (k1 + 2. * k2 + 2. * k3 + k4)
    return X

# Dummy right-hand side for testing RK4.
@numba.jit(nopython=True)
def fff(t, X):
    """Return the constant derivative vector ``[0.0, 3.0]``.

    ``t`` and ``X`` are ignored. The value matches the earlier version,
    which overwrote them with 1 and 3 and returned ``[0, t*X]``.

    The result is a NumPy float64 array, not a Python list. RK4 evaluates
    ``hdt * k1`` on this return value, and numba's nopython mode has no
    ``float64 * list`` implementation. Returning a list is what produced
    ``No implementation of function ... mul(float64, list(int64))``.
    """
    return np.array([0.0, 3.0])

用於運行的主程序:

import numpy as np
import numba
import swing  # RK4 and fff are defined in swing.py; `swing` must be imported before use

# Integrate the dummy system from t=0 to t=10 (exclusive) with step 0.1,
# starting from the state [0, 1].
swing.RK4(swing.fff, 10, np.array([0., 1.]), 0.1)

運行後出現以下錯誤消息,但我看不出這段簡單的代碼哪裡不正確:

---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
Input In [2], in <cell line: 1>()
----> 1 swing.RK4(swing.fff, 10, np.array([0,1]), 0.1)

File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File ~/miniconda3/lib/python3.9/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function mul>) found for signature:
 
 >>> mul(float64, list(int64)<iv=[0]>)
 
There are 14 candidate implementations:
  - Of which 12 did not match due to:
  Overload of function 'mul': File: <numerous>: Line N/A.
    With argument(s): '(float64, list(int64)<iv=None>)':
   No match.
  - Of which 2 did not match due to:
  Operator Overload in function 'mul': File: unknown: Line unknown.
    With argument(s): '(float64, list(int64)<iv=None>)':
   No match for registered cases:
    * (int64, int64) -> int64
    * (int64, uint64) -> int64
    * (uint64, int64) -> int64
    * (uint64, uint64) -> uint64
    * (float32, float32) -> float32
    * (float64, float64) -> float64
    * (complex64, complex64) -> complex64
    * (complex128, complex128) -> complex128

During: typing of intrinsic-call at /disk/disk2/youngjin/workspace/workspace/DS/Inference/MCMC/Swing/swing.py (36)

File "swing.py", line 36:
def RK4(func, t_end, X0, dt):
    <source elided>
        t2 = t[i] + hdt
        x2 = X[i] + hdt * k1
        ^

有人能找出錯誤的原因和解決方法嗎?

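解決方案

錯誤的原因在於 fff 返回的是 Python 列表(list)。RK4 中的 x2 = X[i] + hdt * k1 需要把浮點數 hdt 與 k1 相乘,而 numba 的 nopython 模式沒有 float64 * list 的實現,因此報出 mul(float64, list(int64)) 的 TypingError。讓被積分的函數返回 NumPy 浮點數組(np.ndarray)即可解決。下面的代碼中,swing 函數返回 np.ndarray,模型參數則作為額外參數傳給 RK4。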

在 mycode.py 中:
import numpy as np
from scipy import integrate
from typing import Union, List
import numba

def AdjMtoAdjL(adjM: np.ndarray) -> list:
    """Convert a 2-D adjacency matrix to an adjacency list.

    Entry ``i`` of the returned list is a 1-D integer array of the row
    indices ``j`` for which ``adjM[j, i] > 0``, i.e. the positive entries of
    column ``i``.

    One entry is produced per column (``adjM.shape[1]``). The previous loop
    ran over ``len(adjM)`` (the row count), which only matched for square
    matrices.
    """
    return [np.flatnonzero(column > 0) for column in adjM.T]
def AdjMtoEdgL(adjM: np.ndarray) -> np.ndarray:
    """Convert an adjacency matrix to an edge list.

    Returns an integer array with one row per entry of ``adjM`` that is
    greater than zero. Each row holds that entry's index (``[row, col]`` for
    a 2-D matrix), in row-major order.
    """
    has_edge = adjM > 0
    edges = np.argwhere(has_edge)
    return edges

@numba.jit(nopython=True)
def swing(t, y, phi, m, gamma, P, K, model):
    r"""Right-hand side for a single oscillator, for use with ``RK4``.

    With ``model == "swing"`` the state is ``y = (theta, omega)`` and the
    code evaluates

    .. math::

        \dot{\theta} = \omega, \qquad
        \dot{\omega} = \frac{1}{m}\left(P - \gamma\omega - K\sin(\theta - \phi)\right)

    For any other ``model``, ``y`` is the phase array itself and the code
    evaluates :math:`\dot{\theta} = P + K\sin(\theta - \phi)`.

    NOTE(review): the two branches use opposite signs for the coupling
    term. The equation originally written in this function's body showed
    ``+\Sigma K\sin(\theta-\phi)`` for the swing branch. Confirm which sign
    is intended; the formulas above document what the code computes.

    Parameters
    ----------
    t : float
        Time. Unused, but required so ``RK4`` can call
        ``func(t, X, phi, m, gamma, P, K, model)``.
    y : 1-D float array
        State. It must have exactly two entries when ``model == "swing"``.
    phi, m, gamma, P, K
        Parameters of the equations above (floats in the parameter dict
        used by the calling notebook).
    model : str
        ``"swing"`` selects the second-order model; anything else selects
        the first-order phase model.

    Returns
    -------
    dydt : 1-D float array
        ``[dtheta/dt, domega/dt]`` for ``"swing"``, otherwise ``dtheta/dt``
        (same shape as ``y`` for scalar parameters).
    """
    if model == "swing":
        T, O = y
        # Wrap the scalars in arrays so np.concatenate below can join the
        # two derivatives.
        T = np.array([T])
        O = np.array([O])
    else:
        T = y

    # Coupling term K*sin(theta - phi).
    Interaction = K*np.sin(T-phi)
    if model == "swing":
        dT = O
        dO = 1/m*(P - gamma*O - Interaction)
        dydt = np.concatenate((dT, dO))
    else:
        dydt = P + Interaction
    return dydt

@numba.jit(nopython=True)
def RK4(func, t_end, X0, dt, phi, m, gamma, P, K, model):
    """Integrate ``dX/dt = func(t, X, phi, m, gamma, P, K, model)`` with classical RK4.

    Parameters
    ----------
    func : numba-jitted function
        Right-hand side, called as ``func(t, X, phi, m, gamma, P, K, model)``.
        In nopython mode it must itself be compiled with numba. It must also
        return a 1-D NumPy float array with the same length as ``X``: a
        Python list would fail to type-check in ``hdt * k1``.
    t_end : float
        End time (exclusive, as with ``np.arange``).
    X0 : 1-D array
        Initial state. Its length sets the number of state variables.
    dt : float
        Step size.
    phi, m, gamma, P, K, model
        Passed unchanged to every ``func`` call.

    Returns
    -------
    X : 2-D float64 array, shape ``(len(t), len(X0))``
        ``X[i]`` is the state at time ``t[i]``, where
        ``t = np.arange(0, t_end, dt)``. Because ``dt`` is a float, rounding
        can make ``t`` one element longer or shorter than ``t_end / dt``
        suggests.
    """
    t = np.arange(0, t_end, dt, dtype=np.float64)
    X = np.zeros((t.shape[0], X0.shape[0]))
    if t.shape[0] == 0:
        # No time points (e.g. t_end <= 0), so there is no row to store X0 in.
        # Writing X[0] anyway would be out of bounds, and numba does not
        # bounds-check indexing by default.
        return X
    X[0] = X0
    hdt = dt * .5
    for i in range(t.shape[0] - 1):
        ti = t[i]
        xi = X[i]  # view of row i; only row i + 1 is written in this iteration
        k1 = func(ti, xi, phi, m, gamma, P, K, model)
        k2 = func(ti + hdt, xi + hdt * k1, phi, m, gamma, P, K, model)
        k3 = func(ti + hdt, xi + hdt * k2, phi, m, gamma, P, K, model)
        k4 = func(ti + dt, xi + dt * k3, phi, m, gamma, P, K, model)
        X[i + 1] = xi + dt / 6. * (k1 + 2. * k2 + 2. * k3 + k4)
    return X
在主程序(.ipynb 筆記本)中:
import networkx as nx
import os
import multiprocessing as mp
from multiprocessing import Pool
import time
import numpy as np
import swing

def multiprocess(Ngrid=101, t_end=30., omega_lim=30, dt=.001, n_cpu=19,
                 swing_parameters=None):
    """Integrate the swing model on a grid of initial conditions in parallel.

    Builds an ``Ngrid x Ngrid`` grid of initial states ``(theta, omega)``.
    ``theta`` spans ``[0, 2*pi]`` and ``omega`` spans
    ``[-omega_lim, omega_lim]``. ``solve_func`` is mapped over the grid with
    a process pool of ``n_cpu`` workers. The wall-clock time (whole seconds)
    and the worker count are printed.

    Parameters
    ----------
    Ngrid : int
        Number of grid points along each axis.
    t_end, dt : float
        End time and step size forwarded to ``solve_func``.
    omega_lim : float
        Half-width of the ``omega`` range.
    n_cpu : int
        Number of worker processes.
    swing_parameters : dict, optional
        Model parameters sent to every worker as ``params['sparam']``.
        Defaults to the module-level ``Swing_Parameters``, which is what this
        function always used before the parameter was added.

    Returns
    -------
    list
        One ``solve_func`` result per grid point. ``theta`` is the outer loop
        and ``omega`` the inner loop.
    """
    if swing_parameters is None:
        swing_parameters = Swing_Parameters
    start = int(time.time())

    T_range = np.linspace(0, 2 * np.pi, Ngrid)
    O_range = np.linspace(-omega_lim, omega_lim, Ngrid)

    paramss = []
    for theta in T_range:
        for omega in O_range:
            paramss.append({
                'sparam': swing_parameters,
                't_end': t_end,
                'init': np.array([theta, omega]),  # (theta, omega)
                'dt': dt,
            })

    # The with-block shuts the pool down once map() has returned every
    # result. Without it, the worker processes are never released.
    with Pool(processes=n_cpu) as p:
        result = p.map(solve_func, paramss)

    end = int(time.time())
    print("***run time(sec) : ", end-start)
    print("Number of Core : " + str(n_cpu))
    return result

def solve_func(params):
    """Integrate one initial condition with ``swing.RK4``.

    Parameters
    ----------
    params : dict
        Must provide ``'sparam'``, ``'t_end'``, ``'init'`` (the initial
        state) and ``'dt'``, as built by ``multiprocess``. ``'sparam'`` must
        in turn provide ``"phi"``, ``"m"``, ``"gamma"``, ``"P"``, ``"K"``
        and ``"model"``.

    Returns
    -------
    The array returned by ``swing.RK4`` for this initial condition.

    Raises
    ------
    KeyError
        If any of the keys above is missing.
    """
    sparam = params['sparam']
    return swing.RK4(
        swing.swing,
        params['t_end'],
        params['init'],
        params['dt'],
        sparam["phi"],
        sparam["m"],
        sparam["gamma"],
        sparam["P"],
        sparam["K"],
        sparam["model"],
    )
    
# Simulation settings.
Ngrid = 301
t_end = 24.
omega_lim = 30
dt = .05

# Model parameters. multiprocess() sends this dict to every worker as
# params['sparam'].
Swing_Parameters = {
    "phi": np.pi,
    "m": 1.,
    "gamma": 0.3,
    "P": 2.,
    "K": 8.,
    "model": "swing"
}

# Only the launching process runs the simulation. Under the "spawn" start
# method (the default on Windows and macOS), each Pool worker re-imports the
# main module, and unguarded top-level code would start the whole run again.
# In a Jupyter notebook __name__ is "__main__", so behavior there is unchanged.
if __name__ == "__main__":
    # NOTE(review): `model` is not used below, and SwingSingle is not among
    # the definitions posted for the swing module. Confirm it exists or
    # remove this line.
    model = swing.SwingSingle(**Swing_Parameters)

    res = multiprocess(Ngrid=Ngrid, t_end=t_end, omega_lim=omega_lim, dt=dt, n_cpu=19)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM