[英]Slow Numba performance on convolution
我正在使用以下代碼使用 Numba 在兩個 3D arrays 的組件之間進行卷積:
@jit( nopython=True, parallel=True, nogil=True )
def calculate_convs_products( x_grid, fn, gn, indices_pairs_array):
total_items = fn.shape[1]
total_points = x_grid.shape[0]
#result = []
final_res = np.zeros( (fn.shape[1],fn.shape[-1]), dtype=fn.dtype )
#indices_pairs_array is an array of the type [[0,0],[0,1],[0,2],[1,0],[1,1],
# [1,2]...]
#with all N**2 pairings of N integers from 0 to N-1.
tot_pairs = indices_pairs_array.shape[0]
for l in prange( tot_pairs ):
f = fn[indices_pairs_array[l,0]]
g = gn[indices_pairs_array[l,1]]
result = []
for k in range( total_items ):
res_k = []
for x_i in range( total_points ):
index = x_i - int( total_points/2 )
gs = np.roll( g[k,:], index )
if ( index < 0 ):
gs[ index: ] = g[ k, -1 ]
elif ( index > 0 ):
gs[ 0:index ] = g[ k, 0 ]
res = trapzl( f[ k, : ]*gs , x_grid )
res_k.append( res )
result.append( res_k )
result = np.array( result )
final_res += result
return final_res
@jit( nopython=True )
def trapzl(y, x):
"Pure python version of trapezoid rule."
s = 0
for i in range(1, len(x)):
s += (x[i]-x[i-1])*(y[i]+y[i-1])
return s/2
請注意,進入 function 的 arrays fn
和gn
是 3D,並且卷積發生在最后一個軸上。 但是,使用上述實現時,我的表現真的很差。 此外,Numba 沒有正確並行化外部l
循環。 什么可能會減慢這一速度,我們如何才能提高效率?
首先,您不能以您的方式並行化循環,因為它會導致競爭條件。 實際上,共享數組final_res
是由多個線程讀寫的。 在 Numba 中,線程需要在不同的 memory 區域數組上工作,或者讀取其他線程未觸及的 memory 區域。 否則,必須進行同步(AFAIK,Numba 不提供同步線程的高級方法,因為它基於簡單的 fork-join model,因此在這種情況下應該需要多個並行循環)。
請注意,由於append
調用,基於k
和基於x_i
的循環也不能使用prange
簡單地並行化,並且並行化trapzl
不應該是有效的,因為這個 function 所花費的時間可能應該很小(並且向線程發送工作很昂貴) . 因此,這意味着代碼不能簡單有效地並行化。
請注意,在 Numba 中,所有變量都是類型化的,它們只能有 1 個唯一類型。 事情是result
被初始化為一個類型化的列表,你做result = np.array( result )
這意味着它也應該是一個np.array
。 即使這樣的行可以工作,它也會引入愚蠢的低效隱式轉換。 您可以使用另一個變量名來解決這個問題。 我希望 Numba 抱怨代碼無法編譯(導致純 Python 執行)。
希望可以改進順序代碼。 首先,您不需要使用(慢)列表,因為您知道循環迭代范圍。 您可以直接創建大小合適的數組result
(total_items, total_points)
並填充它。 實際上,您甚至不能創建數組,因為它僅用於在final_res
中累積一些值。 如果 final_res 的大小與result
相同(即沒有廣播),您可以直接在final_res
中累積值。
最后,請注意 Numba function 將在 function 的第一次執行期間延遲編譯。 這可能會影響您的基准測試,因為與執行時間相比,編譯這樣的 function 可能會非常慢。 您可以通過提供 function 的簽名來指定輸入類型,以便 Numba 可以急切地編譯它(有關更多信息,請參見此處)。
可能還有更多可以應用的優化,但是如果沒有一些有效的輸入值(尤其是輸入類型)就很難判斷。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.