繁体   English   中英

使用具有numpy矩阵的Strassen算法的错误输出矩阵

[英]Incorrect output matrix using Strassen's algorithm with numpy matrices

我正在尝试使用Python 3和numpy矩阵实现CLRS中所述的Strassen矩阵乘法算法。

问题在于,输出矩阵C作为零矩阵而不是正确的乘积返回。 我不确定为什么我的实现无法正常工作,但是怀疑这与每次递归调用中创建C矩阵有关。 对于我做错了什么以及如何解决它的任何解释,我将不胜感激。

谢谢!

import numpy as np

def strassen(A,B):
    n = A.shape[0]
    C = np.zeros((n*n), dtype=np.int).reshape(n,n)
    if n == 1:
        C[0][0] = A[0][0] * B[0][0]

    else:
        k = int(n/2) 

        A11,A21,A12,A22 = A[:k,:k], A[k:, :k], A[:k, k:], A[k:, k:]
        B11,B21,B12,B22 = B[:k,:k], B[k:, :k], B[:k, k:], B[k:, k:]
        C11,C21,C12,C22 = C[:k,:k], C[k:, :k], C[:k, k:], C[k:, k:]

        S1 = B12 - B22
        S2 = A11 + A12
        S3 = A21 + A22
        S4 = B21 - B11
        S5 = A11 + A22
        S6 = B11 + B22
        S7 = A12 - A22
        S8 = B21 + B22
        S9 = A11 - A21
        S10= B11 + B12

        P1 = strassen(A11, S1)
        P2 = strassen(S2, B22)
        P3 = strassen(S3, B11)
        P4 = strassen(A22, S4)
        P5 = strassen(S5, S6)
        P6 = strassen(S7, S8)
        P7 = strassen(S9, S10)

        C11 = P5 + P4 - P2 + P6
        C12 = P1 + P2
        C21 = P3 + P4
        C22 = P5 + P1 - P3 - P7


    return C

好吧,我可以通过简单地用新值更新切片C [:k,:k]而不是创建新变量C11,C12..ect来使其工作。 因为这样做会创建一个新矩阵,而不是对原始矩阵C的引用。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM