簡體   English   中英

numba 的表現(我做錯了嗎?)

[英]Performance of numba (did I do something wrong?)

我目前正在編寫一個遍歷 python 中的圖的算法。 該圖連接到一個基礎方程系統,在遍歷過程中,我必須提取和存儲一些索引。 我一開始使用 networkx 來實現它,但是因為方程系統和連接圖都變得很大,所以算法太慢了。

然后我切換到純 numpy 實現。 這更快,但仍然不夠快。 我認為 numba 會更快,但它似乎更慢。 在我測量了計算時間之后,我注意到主要問題出現在調用以下 function 期間:

@jit(nopython=True)
def add_col_to_mad_schedule_numba(rows, cols, data,  mad_schedule_col, edges, col, total_fillin_rows, total_fillin_cols,
                            total_fillin_data, nodes):

    colidx = nbfunc.where_single(cols, col)
    for k in range(len(rows[colidx])):
        if nbfunc.contained(nodes, rows[colidx][k]):
            edge_idx = nbfunc.where_single(edges[:, 0], rows[colidx][k])
            if edges[edge_idx].size != 0:
                kstart = edges[edge_idx, 0]
                ending = False
                while ending == False:
                    edges, k_filter = det_edges(kstart, edges)
                    k_filter = np.array(k_filter)
                    if k_filter.size == 0:
                        ending=True
                    else:
                        rows, cols, data, total_fillin_rows, total_fillin_cols, total_fillin_data = det_fillin(rows, cols, data, total_fillin_rows, total_fillin_cols, total_fillin_data, col, k_filter, edges)
                        first_idx, sec_idx, third_idx = mad_numba(rows, cols, edges, k_filter, col)

                        newl = np.zeros((len(edges[k_filter, 1]), 6), dtype=np.int64)
                        newl[:, 0] = edges[k_filter, 0]
                        newl[:, 1] = edges[k_filter, 1]
                        newl[:, 2] = first_idx
                        newl[:, 3] = sec_idx
                        newl[:, 4] = third_idx
                        newl[:, 5] = col


                        mad_schedule_col = np.append(mad_schedule_col, newl, axis=0)

                        kstart = edges[k_filter, 1]


    return rows, cols, data, total_fillin_rows, total_fillin_cols, total_fillin_data, mad_schedule_col[1:]

這個 function 被調用 n 次,其中 n 是方程系統中的變量數。 function 的每次運行當前需要 61 毫秒,我想問一下您是否可以看到由於錯誤使用 numba 引起的任何技術瓶頸。 例如,我仍在 function 主體中創建 numpy arrays 。 這樣的事情可能會導致表現不佳嗎?

該算法確實非常耗時,因為對於系統每一列(變量 k)中的每個非零條目,遍歷有向圖直到沒有后繼圖。 遍歷的次數並不多。 while 循環中有大約 3 次迭代。 對於每一列,也只有 3-5 個非零條目。

我也可以提供 det_fillin() 和 mad_numba() 的內容,但我認為不會發生很多事情。 我使用我自己的 numpy where() function 的 numba 等效項檢索了一些索引。

請注意,nbfunc 函數也代表與 numpy 函數等效的函數。 Where_single() 對應於 np.where 並且 contains() 只是檢查 rows[colidx][k] 是否在節點中。 所有函數都使用@jit(nopython=True) 編譯,並且沒有錯誤消息。

問題主要出現在np.append調用中。 它創建一個新的更大的數組並為每個調用復制以前的內容,這顯着增加了算法的復雜度(具有線性復雜度的算法在最壞的情況下可能變成二次方)。 同樣的事情也適用於 Numpy。

One solution to fix this problem is to use a list so to append many Numpy arrays in it and then concatenate all of them in a new bigger Numpy array.

另一種解決方案是直接創建一個大小合適的大數組然后通過將子數組復制到大數組來填充它 這個解決方案要快得多,但它需要提前知道最終數組的大小。 當代碼可以並行執行時,速度特別快。

另一種解決方案是使用前一個,並進行兩個小的更改:創建最大可能大小的大數組,計算 function 僅返回它的實際填充子集(視圖)。 這個解決方案要求大數組的大小是有界的,並且這個范圍不能比返回視圖的實際平均大小大很多。 這通常比使用列表更快,但它需要更多的 memory。


注釋和備注

請注意:請注意mad_schedule_col =...不會寫入/修改傳入參數的數組,它會創建一個新數組並設置變量mad_schedule_col以引用新創建的數組。 如果你想改變輸入數組,那么你需要寫入它。 如果您不提前知道大小,那么最好簡單地返回修改后的mad_schedule_col

另請注意,如果沒有提供簽名Numba 函數將被延遲編譯(當 function 第一次執行時)並且編譯時間可能會很慢。 提供簽名會導致 Numba 急切地編譯 function(當定義了 function 時)。

請注意,如果newl.shape[0]很大,那么您的算法可能會受到內存限制。 如果是這樣,那么使用 Numba 不會快很多,因為 memory 已經飽和了。

這只是部分答案,因為我根據您的建議更新了 function。 你對 np.append 函數是正確的。 我用大的 arrays 或列表替換了它們,並且每次調用的時間縮短到 25 毫秒。 這仍然是很多時間,我想。 此外,添加簽名稍微增加了計算時間,所以不幸的是這並沒有幫助。

但是我注意到一些奇怪的事情:我為您准備了整個代碼以及必要的輸入變量,並使用 np.save 將它們存儲起來,因此您可以仔細查看它。

我還添加了一個小測試腳本,它只是在 for 循環中運行 function add_col_to_mad_schedule_numba,以便分析器將其放在列表的頂部。 使用存儲的輸入運行此測試腳本后,我注意到在測試腳本中,function 的每次調用時間僅為 2 毫秒,而 function 嵌入到更大的腳本中的計算時間為 25 毫秒。 這怎么可能?

似乎還是有什么不對勁。 我嘗試構建的是線性方程組的並行稀疏求解器。 沒什么新鮮的,scipy 實現了 SuperLU,它在某些階段也使用圖遍歷(但在其他方面工作不同)。 然而,在 SuperLU 中求解的相同方程組只需要幾毫秒左右的時間。 我知道使用 numba 可能無法實現這些時間,但必須有一些選項來進一步減少此圖遍歷的時間。

如果您查看了代碼並想要運行測試腳本,請注意:

  1. 您必須在主要部分的 with open 命令中修改文件路徑。
  2. 如果您查看 numba 函數,您會注意到仍有一些命令我使用 np.append。 這是因為其他方式會導致奇怪的錯誤消息。 我稍后會處理這個。

可以通過以下鏈接下載代碼和輸入:

https://drive.google.com/drive/folders/1hBvFmmZtJ2rRf5VaQFJX8pjCQtLSD9-0?usp=sharing

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM