Cython not fast enough

I rewrote my Python loop in Cython expecting a large speed improvement, but I only get about a factor of four. Am I doing something wrong? This is the code without Cython:

import numpy as np
import itertools as itr
import math

def Pk (b, f, mu, k): # k is in Mpc
    isoPk = 200*math.exp(-(k-0.02)**2/2/0.01**2) # Isotropic power spectrum
    power = (b+mu**2*f)**2*isoPk
    return power

def Gendk (N, kvec, Pk, b, f, deltak3d):
    Nhalf = int(N/2)
    for xx, yy, zz in itr.product(range(0,N), range(0,N), range(0,Nhalf+1)):
        kx = kvec[xx]
        ky = kvec[yy]
        kz = kvec[zz]
        kk = math.sqrt(kx**2+ky**2+kz**2)
        if kk == 0:
            continue
        mu = kz/kk
        power = Pk(b, f, mu, kk)
        if power==0:
            deltaRe = 0 
            deltaIm = 0
        else:
            deltaRe = np.random.normal(0, power/2.0)
            if (xx==0 or xx==Nhalf) and (yy==0 or yy==Nhalf) and (zz==0 or zz==Nhalf):
                deltaIm = 0
            else:
                deltaIm = np.random.normal(0, power/2.0)
        x_conj = (2*N-xx)%N
        y_conj = (2*N-yy)%N
        z_conj = (2*N-zz)%N
        deltak3d[xx,yy,zz] = deltaRe + deltaIm*1j
        deltak3d[x_conj,y_conj,z_conj] = deltaRe - deltaIm*1j 

Ntot = 300000
L = 1000 
N = 128 
Nhalf = int(N/2)
kmax = 5.0 
dk = kmax/N
kvec = np.fft.fftfreq(N, L/N)
dL = L/N
deltak3d = np.zeros((N,N,N), dtype=complex)
deltak3d[0,0,0] = Ntot
Gendk(N, kvec, Pk, 2, 1, deltak3d)

This is the version with Cython:

import numpy as np
import pyximport; pyximport.install(setup_args={"include_dirs":np.get_include()})
import testGauss as tG

Ntot = 300000
L = 1000 
N = 128 
Nhalf = int(N/2)
kmax = 5.0 
dk = kmax/N
kvec = np.fft.fftfreq(N, L/N)
dL = L/N 
deltak3d = np.zeros((N,N,N), dtype=complex)
deltak3d[0,0,0] = Ntot
tG.Gendk(N, kvec, tG.Pk, 2, 1, deltak3d)

And the testGauss.pyx file is:

import math
import numpy as np
cimport numpy as np
import itertools as itr

def Pk (double b, double f, double mu, double k): # k is in Mpc
    cdef double isoPk, power
    isoPk = 200*math.exp(-(k-0.02)**2/2/0.01**2) # Isotropic power spectrum
    power = (b+mu**2*f)**2*isoPk
    return power

def Gendk (int N, np.ndarray[np.float64_t,ndim=1] kvec, Pk, double b, double f, np.ndarray[np.complex128_t,ndim=3] deltak3d):
    cdef int Nhalf = int(N/2)
    cdef int xx, yy, zz
    cdef int x_conj, y_conj, z_conj
    cdef double kx, ky, kz, kk
    cdef mu
    cdef power
    cdef deltaRe, deltaIm
    for xx, yy, zz in itr.product(range(0,N), range(0,N), range(0,Nhalf+1)):
        kx = kvec[xx]
        ky = kvec[yy]
        kz = kvec[zz]
        kk = math.sqrt(kx**2+ky**2+kz**2)
        if kk == 0:
            continue
        mu = kz/kk
        power = Pk(b, f, mu, kk)
        if power==0:
            deltaRe = 0 
            deltaIm = 0
        else:
            deltaRe = np.random.normal(0, power/2.0)
            if (xx==0 or xx==Nhalf) and (yy==0 or yy==Nhalf) and (zz==0 or zz==Nhalf):
                deltaIm = 0
            else:
                deltaIm = np.random.normal(0, power/2.0)
        x_conj = (2*N-xx)%N
        y_conj = (2*N-yy)%N
        z_conj = (2*N-zz)%N
        deltak3d[xx,yy,zz] = deltaRe + deltaIm*1j
        deltak3d[x_conj,y_conj,z_conj] = deltaRe - deltaIm*1j 

Thank you very much in advance!

Use cProfile to profile your Python code first. If most of the CPU time is already spent inside NumPy, there is not much to gain from Cython.
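A minimal sketch of that profiling step, using only the standard-library cProfile and pstats modules. It assumes the pure-Python Pk and Gendk from the question (copied here so the snippet is self-contained) and a small grid so it finishes quickly; profile_gendk is a hypothetical helper name, not part of any library.

```python
import cProfile
import io
import itertools as itr
import math
import pstats

import numpy as np


def Pk(b, f, mu, k):  # same model as in the question
    isoPk = 200 * math.exp(-(k - 0.02) ** 2 / 2 / 0.01 ** 2)
    return (b + mu ** 2 * f) ** 2 * isoPk


def Gendk(N, kvec, Pk, b, f, deltak3d):  # same loop as in the question
    Nhalf = N // 2
    for xx, yy, zz in itr.product(range(N), range(N), range(Nhalf + 1)):
        kx, ky, kz = kvec[xx], kvec[yy], kvec[zz]
        kk = math.sqrt(kx ** 2 + ky ** 2 + kz ** 2)
        if kk == 0:
            continue
        power = Pk(b, f, kz / kk, kk)
        if power == 0:
            deltaRe = deltaIm = 0.0
        else:
            deltaRe = np.random.normal(0, power / 2.0)
            if xx in (0, Nhalf) and yy in (0, Nhalf) and zz in (0, Nhalf):
                deltaIm = 0.0
            else:
                deltaIm = np.random.normal(0, power / 2.0)
        deltak3d[xx, yy, zz] = deltaRe + deltaIm * 1j
        deltak3d[(2 * N - xx) % N, (2 * N - yy) % N, (2 * N - zz) % N] = deltaRe - deltaIm * 1j


def profile_gendk(N=32, L=1000):
    """Profile a single Gendk call and return (stats, text report)."""
    kvec = np.fft.fftfreq(N, L / N)
    deltak3d = np.zeros((N, N, N), dtype=complex)
    profiler = cProfile.Profile()
    profiler.enable()
    Gendk(N, kvec, Pk, 2, 1, deltak3d)
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream).sort_stats("cumulative")
    stats.print_stats(10)  # keep only the top 10 entries by cumulative time
    return stats, stream.getvalue()


if __name__ == "__main__":
    _, report = profile_gendk()
    print(report)
```

The report lists the ten entries with the largest cumulative time. Comparing the time spent inside NumPy's normal with the time in Pk, math.sqrt and the loop body itself shows where a rewrite could actually pay off.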

You could get some speedup by replacing

import math

with

from libc cimport math

That avoids a Python function call each time you compute sqrt and exp, replacing it with a direct C call, which should be much faster.

I'm also slightly concerned about the calls to np.random.normal inside your loop, each of which adds noticeable Python overhead. It may well be quicker to generate one large array of random numbers before the loop (paying the overhead of a single Python call) and then overwrite entries with 0 inside the loop where they aren't needed.
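A sketch of that suggestion in plain Python, assuming NumPy 1.17 or later for np.random.default_rng. Gendk_pregen is a hypothetical name: it draws all standard-normal numbers up front with two NumPy calls and rescales them inside the loop. Because np.random.normal(0, s) has the same distribution as s times a standard normal draw, the statistics match the original, although the actual random values differ. Points with zero power automatically get 0 because their scale is 0.

```python
import itertools as itr
import math

import numpy as np


def Pk(b, f, mu, k):  # same model as in the question
    isoPk = 200 * math.exp(-(k - 0.02) ** 2 / 2 / 0.01 ** 2)
    return (b + mu ** 2 * f) ** 2 * isoPk


def Gendk_pregen(N, kvec, Pk, b, f, deltak3d, seed=None):
    """Variant of Gendk that draws all Gaussian numbers in two NumPy calls."""
    Nhalf = N // 2
    rng = np.random.default_rng(seed)
    # One call each instead of one np.random.normal call per grid point.
    # normal(0, s) has the same distribution as s * standard_normal(),
    # so the pre-drawn values are rescaled inside the loop.
    std_re = rng.standard_normal((N, N, Nhalf + 1))
    std_im = rng.standard_normal((N, N, Nhalf + 1))
    for xx, yy, zz in itr.product(range(N), range(N), range(Nhalf + 1)):
        kx, ky, kz = kvec[xx], kvec[yy], kvec[zz]
        kk = math.sqrt(kx ** 2 + ky ** 2 + kz ** 2)
        if kk == 0:
            continue
        sigma = Pk(b, f, kz / kk, kk) / 2.0  # same scale as normal(0, power/2.0)
        deltaRe = sigma * std_re[xx, yy, zz]  # sigma == 0 gives 0, as before
        if xx in (0, Nhalf) and yy in (0, Nhalf) and zz in (0, Nhalf):
            deltaIm = 0.0  # self-conjugate modes must be real
        else:
            deltaIm = sigma * std_im[xx, yy, zz]
        deltak3d[xx, yy, zz] = deltaRe + deltaIm * 1j
        deltak3d[(2 * N - xx) % N, (2 * N - yy) % N, (2 * N - zz) % N] = deltaRe - deltaIm * 1j
```

Passing the same seed twice gives identical grids, which makes it easier to compare this version against a Cython or Pythran port. Unused draws (for example the imaginary parts of self-conjugate modes) are simply wasted, which is typically much cheaper than one Python-level call per grid point.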

The general advice for optimising Cython still applies: run

cython -a your_file.pyx

Look at the generated HTML report and focus on the lines highlighted in yellow (these interact with the Python runtime), but only if they are executed often.

Turning your code (slightly modified) into a native module with Pythran gives me a roughly 50x speedup.

import numpy as np
import itertools as itr
import math
from random import gauss as normal

def Pk (b, f, mu, k): # k is in Mpc
    isoPk = 200*math.exp(-(k-0.02)**2/2/0.01**2) # Isotropic power spectrum
    power = (b+mu**2*f)**2*isoPk
    return power

#pythran export Gendk(int, float[], int, int, complex[][][])
def Gendk (N, kvec, b, f, deltak3d):
    Nhalf = int(N/2)
    for xx, yy, zz in itr.product(range(0, N), range(0, N), range(0, Nhalf+1)):
        kx = kvec[xx]
        ky = kvec[yy]
        kz = kvec[zz]
        kk = math.sqrt(kx**2+ky**2+kz**2)
        if kk == 0:
            continue
        mu = kz/kk
        power = Pk(b, f, mu, kk)
        if power == 0:
            deltaRe = 0 
            deltaIm = 0
        else:
            # deltaRe = np.random.normal(0, power/2.0)
            deltaRe = normal(0, power/2.0)
            if (xx == 0 or xx == Nhalf) and (yy == 0 or yy == Nhalf) and (zz == 0 or zz == Nhalf):
                deltaIm = 0
            else:
                #deltaIm = np.random.normal(0, power/2.0)
                deltaIm = normal(0, power/2.0)
        x_conj = (2*N-xx)%N
        y_conj = (2*N-yy)%N
        z_conj = (2*N-zz)%N
        deltak3d[xx, yy, zz] = deltaRe + deltaIm*1j
        deltak3d[x_conj, y_conj, z_conj] = deltaRe - deltaIm*1j 

Compiled with:

$ pythran tg.py

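And tested with the following command (note that it uses a smaller grid, N = 12, and Ntot = 30000, rather than the values in the question):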

$ python -m timeit -s 'import numpy as np; Ntot = 30000; L = 1000; N = 12; Nhalf = int(N/2); kmax = 5.0; dk = kmax/N; kvec = np.fft.fftfreq(N, L/N); dL = L/N; deltak3d = np.zeros((N, N, N), dtype=complex); deltak3d[0, 0, 0] = Ntot; from tg import Gendk' 'Gendk(N, kvec, 2, 1, deltak3d)'

For the CPython run I get "10 loops, best of 3: 29.4 msec per loop", and for the Pythran run "1000 loops, best of 3: 587 usec per loop".
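These timings were taken on an N = 12 grid, not the question's N = 128. As a rough back-of-the-envelope check, the triple loop visits N * N * (N//2 + 1) index triples, so the snippet below computes the iteration counts, the average time per iteration and the speedup. Treating time as linear in the iteration count is an assumption, not a measurement; loop_iterations and per_iteration_us are hypothetical helper names.

```python
def loop_iterations(N):
    """Number of (xx, yy, zz) triples visited by the itertools.product loop."""
    return N * N * (N // 2 + 1)


def per_iteration_us(total_ms, N):
    """Average time per loop iteration in microseconds."""
    return total_ms * 1000.0 / loop_iterations(N)


small, large = loop_iterations(12), loop_iterations(128)
print(small, large)                           # 1008 1064960
print(round(large / small, 1))                # 1056.5
print(round(per_iteration_us(29.4, 12), 2))   # CPython: 29.17
print(round(per_iteration_us(0.587, 12), 3))  # Pythran: 0.582
print(round(29.4 / 0.587, 1))                 # speedup: 50.1
```

Under that linear-scaling assumption, the question's N = 128 grid has about 1056 times as many iterations as the benchmarked grid, so absolute timings from the two settings are not directly comparable, while the speedup ratio is the more transferable number.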

Disclaimer: I'm a Pythran developer.
