簡體   English   中英

如何用numba加速這個python函數?

[英]How to speed up this python function with numba?

我正在嘗試加速這個 python 函數:

def twoFreq_orig(z, source_z, num, den, matrix, e):
    Z1, Z2 = np.meshgrid(source_z, np.conj(z))
    Z1 **= num
    Z2 **= den - 1
    M = (e ** ((num + den - 2) / 2.0)) * Z1 * Z2
    return np.sum(matrix * M, 1)

其中zsource_znp.ndarray (1d, dtype=np.complex128 ), numdennp.ndarray (2d, dtype=np.float64 ), matrixnp.ndarray (2d, dtype=np.complex128 ) e是一個np.float64

我對Numba沒有太多經驗,但在閱讀了一些教程后,我想出了這個實現:

@nb.jit(nb.f8[:](nb.c16[:], nb.c16[:], nb.f8[:, :], nb.f8[:, :], nb.c16[:, :], nb.f8))
def twoFreq(z, source_z, num, den, matrix, e):
    N1, N2 = len(z), len(source_z)
    out = np.zeros(N1)
    for r in xrange(N1):
        tmp = 0
        for c in xrange(N2):
            n, d = num[r, c], den[r, c] - 1
            z1 = source_z[c] ** n
            z2 = z[r] ** d
            tmp += matrix[r, c] * e ** ((n + d - 1) / 2.0) * z1 * z2
        out[r] = tmp
    return out

不幸的是,Numba 實現不是加速,而是比原始實現慢了幾倍。 我不知道如何正確使用 Numba。 任何 Numba 大師都可以幫我一把嗎?

實際上,我認為在不深入了解數組屬性的情況下,您無法做很多事情來加速 numba 函數(是否有一些數學技巧可以更快地完成某些計算)。

但我注意到一個錯誤:例如,您沒有在 numba 版本中結合數組,我編輯了一些行以使其更加流暢(其中一些可能只是品味)。 我在適當的地方加入了評論:

@nb.njit
def twoFreq(z, source_z, num, den, matrix, e):
    #Replace z with conjugate of z (otherwise the result is wrong!)
    z = np.conj(z)
    # Size instead of len() don't know if it actually makes a difference but it's cleaner
    N1, N2 = z.size, source_z.size
    # Must be zeros_like otherwise you create a float array where you want a complex one
    out = np.zeros_like(z)
    # I'm using python 3 so you need to replace this by xrange later
    for r in range(N1):
        for c in range(N2):
            n, d = num[r, c], den[r, c] - 1
            z1 = source_z[c] ** n
            z2 = z[r] ** d
            # Multiply with 0.5 instead of dividing by 2
            # Work on the out array directly instead of a tmp variable
            out[r] += matrix[r, c] * e ** ((n + d - 1) * 0.5) * z1 * z2
    return out

def twoFreq_orig(z, source_z, num, den, matrix, e):
    Z1, Z2 = np.meshgrid(source_z, np.conj(z))
    Z1 **= num
    Z2 **= den - 1
    M = (e ** ((num + den - 2) / 2.0)) * Z1 * Z2
    return np.sum(matrix * M, 1)


numb = 1000
z = np.random.uniform(0,1,numb) + 1j*np.random.uniform(0,1,numb)
source_z = np.random.uniform(0,10,numb) + 1j*np.random.uniform(0,1,numb)
num = np.random.uniform(0,1,(numb,numb))
den = np.random.uniform(0,1,(numb,numb))
matrix = np.random.uniform(0,1,(numb,numb)) + 1j*np.random.uniform(0,1,(numb, numb))
e = 5.5

# This failed for your initial version:
np.testing.assert_array_almost_equal(twoFreq(z, source_z, num, den, matrix, e),
                                     twoFreq_orig(z, source_z, num, den, matrix, e))

我電腦上的運行時間是:

%timeit twoFreq(z, source_z, num, den, matrix, e)

1 個循環,最好的 3 個:每個循環 246 毫秒

%timeit twoFreq_orig(z, source_z, num, den, matrix, e)

1 個循環,最好的 3 個:每個循環 344 毫秒

它比您的 numpy 解決方案快約 30%。 但是我認為通過巧妙地使用廣播可以使 numpy 解決方案更快一點。 但是,盡管如此,我獲得的大部分加速都來自省略簽名:請注意,您可能使用 C 連續數組,但您已經給出了任意排序(因此 numba 可能會慢一點,具體取決於計算機體系結構)。 可能通過定義c16[::-1]你會得到相同的速度,但通常只是讓 numba 推斷類型,它可能會盡可能快。 例外:您希望每個變量具有不同的精度輸入(例如,您希望zcomplex128complex64

當您的 numpy 解決方案內存不足時,您將獲得驚人的加速(因為您的 numpy 解決方案是矢量化的,它將需要更多的 RAM!)當numb = 5000 ,numba 版本比 numpy 版本快大約 3 倍。


編輯:

通過巧妙的廣播,我的意思是

np.conj(z[:,None]**(den-1)) * source_z[None, :]**(num)

等於

z1, z2 = np.meshgrid(source_z, np.conj(z))
z1**(num) * z2**(den-1)

但是對於第一個變體,您只能對numb元素進行冪運算,而您有一個(numb, numb)形狀的數組,因此您執行的“冪”運算比必要的要多得多(即使我猜對於小數組,結果可能主要是緩存的)而且不是很貴)

沒有mgrid numpy 版本(產生相同的結果)如下所示:

def twoFreq_orig2(z, source_z, num, den, matrix, e):
    z1z2 = source_z[None,:]**(num) * np.conj(z)[:, None]**(den-1)
    M = (e ** ((num + den - 2) / 2.0)) * z1z2
    return np.sum(matrix * M, 1)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM