簡體   English   中英

矩陣乘法:Strassen vs. Standard

[英]Matrix multiplication: Strassen vs. Standard

我嘗試使用C ++實現Strassen算法進行矩陣乘法,但結果不是我所期望的。 正如您所看到的,strassen總是花費更多時間,然后標准實現,並且只有2的冪的維度與標准實現一樣快。 什么地方出了錯? 替代文字

matrix mult_strassen(matrix a, matrix b) {
if (a.dim() <= cut)
    return mult_std(a, b);

matrix a11 = get_part(0, 0, a);
matrix a12 = get_part(0, 1, a);
matrix a21 = get_part(1, 0, a);
matrix a22 = get_part(1, 1, a);

matrix b11 = get_part(0, 0, b);
matrix b12 = get_part(0, 1, b);
matrix b21 = get_part(1, 0, b);
matrix b22 = get_part(1, 1, b);

matrix m1 = mult_strassen(a11 + a22, b11 + b22); 
matrix m2 = mult_strassen(a21 + a22, b11);
matrix m3 = mult_strassen(a11, b12 - b22);
matrix m4 = mult_strassen(a22, b21 - b11);
matrix m5 = mult_strassen(a11 + a12, b22);
matrix m6 = mult_strassen(a21 - a11, b11 + b12);
matrix m7 = mult_strassen(a12 - a22, b21 + b22);

matrix c(a.dim(), false, true);
set_part(0, 0, &c, m1 + m4 - m5 + m7);
set_part(0, 1, &c, m3 + m5);
set_part(1, 0, &c, m2 + m4);
set_part(1, 1, &c, m1 - m2 + m3 + m6);

return c; 
}


程序
matrix.h http://pastebin.com/TYFYCTY7
matrix.cpp http://pastebin.com/wYADLJ8Y
main.cpp http://pastebin.com/48BSqGJr

g++ main.cpp matrix.cpp -o matrix -O3

一些想法:

  • 您是否優化過它以考慮用零填充非功率的兩個大小的矩陣? 我認為該算法假設您不打擾這些術語的倍增。 這就是為什么你得到的運行時間在2 ^ n和2 ^(n + 1)-1之間的平坦區域。 通過不將您知道的術語乘以零,您應該能夠改進這些區域。 或許Strassen只能用於2 ^ n大小的矩陣。
  • 考慮到“大”矩陣是任意的,並且該算法僅略微優於天真的情況,O(N ^ 3)對O(N ^ 2.8)。 在嘗試更大的矩陣之前,您可能看不到可衡量的收益。 例如,我做了一些有限元建模,其中10,000x10,000矩陣被認為是“小”。 很難從你的圖表中看出來,但看起來在Stassen案例中511案例可能會更快。
  • 嘗試使用各種優化級別進行測試,包括根本不進行優化。
  • 該算法似乎假設乘法比加法要昂貴得多。 這在40年前首次開發時確實如此,但我相信更現代的處理器,加法和乘法之間的差異變小了。 這可能會降低算法的有效性,這似乎會減少乘法,但會增加相加。
  • 你有沒有看過其他一些Strassen實現的想法? 嘗試對已知良好的實現進行基准測試,以確切了解您可以獲得多快的速度。

好的我不是這個領域的專家,但是這里可能還有其他問題比處理速度快。 首先,strassen方法使用更多堆棧並具有更多函數調用,這增加了內存移動。 你的堆棧越大,你就會受到一定的懲罰,因為它需要從操作系統請求更大的幀。 另外,您使用動態分配,這也是一個問題。

嘗試使用固定大小(帶模板參數)矩陣類? 這至少會解決分配問題。

注意:我不確定它是否與您的代碼一起正常運行。 您的矩陣類使用指針但沒有復制構造函數或賦值運算符。 你也在最后泄漏記憶,因為你沒有析構函數......

與O(N ^ 3)常規相比,Strassen的大O是O(N ^ log 7),即log 7 base 2,其略小於3。

這是您需要進行的乘法次數。

它假設你沒有任何其他成本,並且也應該“更快”,因為N足夠大,而你可能沒有。

你的大部分實現都是創建了很多子矩陣,我的猜測就是你存儲它們的方式,你每次執行此操作都需要分配內存和復制。 有一些“切片”矩陣和邏輯轉置矩陣,如果你可以幫助你優化可能是你的過程中最慢的部分。

我對Stassen multiplcation實現的速度有多快感到震驚:

http://ezekiel.vancouver.wsu.edu/~cs330/lectures/linear_algebra/mm/mm.c

當n = 1024時,我的機器上的速度提高了近16倍。 我可以解釋這個加速的唯一方法是我的算法更加緩存友好 - 也就是說,它專注於矩陣的一小部分,因此數據更加本地化。

C ++實現中的開銷可能過高 - 編譯器生成的臨時數超過了實際需要的數量。 我的實現試圖通過盡可能重用內存來最小化這個。

遠射,但您是否認為標准乘法可能會被編譯器優化? 你能關閉優化嗎​​?

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM