繁体   English   中英

矩阵乘法:Strassen vs. Standard

[英]Matrix multiplication: Strassen vs. Standard

我尝试使用C ++实现Strassen算法进行矩阵乘法,但结果不是我所期望的。 正如您所看到的,strassen总是花费更多时间,然后标准实现,并且只有2的幂的维度与标准实现一样快。 什么地方出了错? 替代文字

matrix mult_strassen(matrix a, matrix b) {
if (a.dim() <= cut)
    return mult_std(a, b);

matrix a11 = get_part(0, 0, a);
matrix a12 = get_part(0, 1, a);
matrix a21 = get_part(1, 0, a);
matrix a22 = get_part(1, 1, a);

matrix b11 = get_part(0, 0, b);
matrix b12 = get_part(0, 1, b);
matrix b21 = get_part(1, 0, b);
matrix b22 = get_part(1, 1, b);

matrix m1 = mult_strassen(a11 + a22, b11 + b22); 
matrix m2 = mult_strassen(a21 + a22, b11);
matrix m3 = mult_strassen(a11, b12 - b22);
matrix m4 = mult_strassen(a22, b21 - b11);
matrix m5 = mult_strassen(a11 + a12, b22);
matrix m6 = mult_strassen(a21 - a11, b11 + b12);
matrix m7 = mult_strassen(a12 - a22, b21 + b22);

matrix c(a.dim(), false, true);
set_part(0, 0, &c, m1 + m4 - m5 + m7);
set_part(0, 1, &c, m3 + m5);
set_part(1, 0, &c, m2 + m4);
set_part(1, 1, &c, m1 - m2 + m3 + m6);

return c; 
}


程序
matrix.h http://pastebin.com/TYFYCTY7
matrix.cpp http://pastebin.com/wYADLJ8Y
main.cpp http://pastebin.com/48BSqGJr

g++ main.cpp matrix.cpp -o matrix -O3

一些想法:

  • 您是否优化过它以考虑用零填充非功率的两个大小的矩阵? 我认为该算法假设您不打扰这些术语的倍增。 这就是为什么你得到的运行时间在2 ^ n和2 ^(n + 1)-1之间的平坦区域。 通过不将您知道的术语乘以零,您应该能够改进这些区域。 或许Strassen只能用于2 ^ n大小的矩阵。
  • 考虑到“大”矩阵是任意的,并且该算法仅略微优于天真的情况,O(N ^ 3)对O(N ^ 2.8)。 在尝试更大的矩阵之前,您可能看不到可衡量的收益。 例如,我做了一些有限元建模,其中10,000x10,000矩阵被认为是“小”。 很难从你的图表中看出来,但看起来在Stassen案例中511案例可能会更快。
  • 尝试使用各种优化级别进行测试,包括根本不进行优化。
  • 该算法似乎假设乘法比加法要昂贵得多。 这在40年前首次开发时确实如此,但我相信更现代的处理器,加法和乘法之间的差异变小了。 这可能会降低算法的有效性,这似乎会减少乘法,但会增加相加。
  • 你有没有看过其他一些Strassen实现的想法? 尝试对已知良好的实现进行基准测试,以确切了解您可以获得多快的速度。

好的我不是这个领域的专家,但是这里可能还有其他问题比处理速度快。 首先,strassen方法使用更多堆栈并具有更多函数调用,这增加了内存移动。 你的堆栈越大,你就会受到一定的惩罚,因为它需要从操作系统请求更大的帧。 另外,您使用动态分配,这也是一个问题。

尝试使用固定大小(带模板参数)矩阵类? 这至少会解决分配问题。

注意:我不确定它是否与您的代码一起正常运行。 您的矩阵类使用指针但没有复制构造函数或赋值运算符。 你也在最后泄漏记忆,因为你没有析构函数......

与O(N ^ 3)常规相比,Strassen的大O是O(N ^ log 7),即log 7 base 2,其略小于3。

这是您需要进行的乘法次数。

它假设你没有任何其他成本,并且也应该“更快”,因为N足够大,而你可能没有。

你的大部分实现都是创建了很多子矩阵,我的猜测就是你存储它们的方式,你每次执行此操作都需要分配内存和复制。 有一些“切片”矩阵和逻辑转置矩阵,如果你可以帮助你优化可能是你的过程中最慢的部分。

我对Stassen multiplcation实现的速度有多快感到震惊:

http://ezekiel.vancouver.wsu.edu/~cs330/lectures/linear_algebra/mm/mm.c

当n = 1024时,我的机器上的速度提高了近16倍。 我可以解释这个加速的唯一方法是我的算法更加缓存友好 - 也就是说,它专注于矩阵的一小部分,因此数据更加本地化。

C ++实现中的开销可能过高 - 编译器生成的临时数超过了实际需要的数量。 我的实现试图通过尽可能重用内存来最小化这个。

远射,但您是否认为标准乘法可能会被编译器优化? 你能关闭优化吗​​?

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM