![](/img/trans.png)
[英]Why is Strassen matrix multiplication so much slower than standard matrix multiplication?
[英]Matrix multiplication: Strassen vs. Standard
我嘗試使用C ++實現Strassen算法進行矩陣乘法,但結果不是我所期望的。 正如您所看到的,strassen總是花費更多時間,然后標准實現,並且只有2的冪的維度與標准實現一樣快。 什么地方出了錯?
matrix mult_strassen(matrix a, matrix b) {
if (a.dim() <= cut)
return mult_std(a, b);
matrix a11 = get_part(0, 0, a);
matrix a12 = get_part(0, 1, a);
matrix a21 = get_part(1, 0, a);
matrix a22 = get_part(1, 1, a);
matrix b11 = get_part(0, 0, b);
matrix b12 = get_part(0, 1, b);
matrix b21 = get_part(1, 0, b);
matrix b22 = get_part(1, 1, b);
matrix m1 = mult_strassen(a11 + a22, b11 + b22);
matrix m2 = mult_strassen(a21 + a22, b11);
matrix m3 = mult_strassen(a11, b12 - b22);
matrix m4 = mult_strassen(a22, b21 - b11);
matrix m5 = mult_strassen(a11 + a12, b22);
matrix m6 = mult_strassen(a21 - a11, b11 + b12);
matrix m7 = mult_strassen(a12 - a22, b21 + b22);
matrix c(a.dim(), false, true);
set_part(0, 0, &c, m1 + m4 - m5 + m7);
set_part(0, 1, &c, m3 + m5);
set_part(1, 0, &c, m2 + m4);
set_part(1, 1, &c, m1 - m2 + m3 + m6);
return c;
}
g++ main.cpp matrix.cpp -o matrix -O3
。
一些想法:
好的我不是這個領域的專家,但是這里可能還有其他問題比處理速度快。 首先,strassen方法使用更多堆棧並具有更多函數調用,這增加了內存移動。 你的堆棧越大,你就會受到一定的懲罰,因為它需要從操作系統請求更大的幀。 另外,您使用動態分配,這也是一個問題。
嘗試使用固定大小(帶模板參數)矩陣類? 這至少會解決分配問題。
注意:我不確定它是否與您的代碼一起正常運行。 您的矩陣類使用指針但沒有復制構造函數或賦值運算符。 你也在最后泄漏記憶,因為你沒有析構函數......
與O(N ^ 3)常規相比,Strassen的大O是O(N ^ log 7),即log 7 base 2,其略小於3。
這是您需要進行的乘法次數。
它假設你沒有任何其他成本,並且也應該“更快”,因為N足夠大,而你可能沒有。
你的大部分實現都是創建了很多子矩陣,我的猜測就是你存儲它們的方式,你每次執行此操作都需要分配內存和復制。 有一些“切片”矩陣和邏輯轉置矩陣,如果你可以幫助你優化可能是你的過程中最慢的部分。
我對Stassen multiplcation實現的速度有多快感到震驚:
http://ezekiel.vancouver.wsu.edu/~cs330/lectures/linear_algebra/mm/mm.c
當n = 1024時,我的機器上的速度提高了近16倍。 我可以解釋這個加速的唯一方法是我的算法更加緩存友好 - 也就是說,它專注於矩陣的一小部分,因此數據更加本地化。
C ++實現中的開銷可能過高 - 編譯器生成的臨時數超過了實際需要的數量。 我的實現試圖通過盡可能重用內存來最小化這個。
遠射,但您是否認為標准乘法可能會被編譯器優化? 你能關閉優化嗎?
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.