简体   繁体   中英

Memory management for Strassen's matrix multiplication

As part of an assignment, I am trying to find the crossover point between Strassen's matrix multiplication and the naive multiplication algorithm. However, I am unable to proceed once the matrix reaches 256x256. Can someone please suggest an appropriate memory management technique so that I can handle larger inputs?

The code is in C as follows:

#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<time.h>

void strassenMul(double* X, double* Y, double* Z, int m);
void matMul(double* A, double* B, double* C, int n);
void matAdd(double* A, double* B, double* C, int m);
void matSub(double* A, double* B, double* C, int m);

int idx = 0;

/*
 * Benchmark driver.
 *
 * For count = 0 .. total-1 it fills two random N x N matrices (N = 2^count),
 * times matMul and strassenMul on them and records both timings. It then
 * advances the global idx based on those timings, asks the user for a matrix
 * size, multiplies two random matrices of that size with strassenMul and
 * prints the inputs and the result.
 */
int main()
{
    int N;
    int count = 0;
    int i, j;
    clock_t start, end;
    double elapsed;

    int total = 15;
    double tnaive[total];
    double tstrassen[total];
    printf("-------------------------------------------------------------------------\n\n");
    for (count = 0; count < total; count++)
    {
        N = pow(2, count);
        printf("Matrix size = %2d\t",N);
        // NOTE(review): four N x N variable-length arrays in automatic storage.
        // Assuming 8-byte doubles, N = 256 already needs 4 * 256 * 256 * 8 bytes
        // = 2 MiB here, before strassenMul adds its own temporaries; the loop
        // is written to go up to N = 2^14.
        double X[N][N], Y[N][N], Z[N][N], W[N][N];
        srand(time(NULL));
        for (i = 0; i < N; i++)
        {
            for (j = 0; j < N; j++)
            {
                // rand()/(RAND_MAX + 1.) yields a value in [0, 1)
                X[i][j] = rand()/(RAND_MAX + 1.);
                Y[i][j] = rand()/(RAND_MAX + 1.);
            }
        }
        start = clock();
        matMul((double *)X, (double *)Y, (double *)W, N);
        end = clock();
        // clock ticks * 100 / CLOCKS_PER_SEC, i.e. hundredths of a second
        elapsed = ((double) (end - start))*100/ CLOCKS_PER_SEC;
        tnaive[count] = elapsed;
        printf("naive = %5.4f\t\t",tnaive[count]);

        start = clock();
        strassenMul((double *)X, (double *)Y, (double *)Z, N);
        end = clock();
        elapsed = ((double) (end - start))*100/ CLOCKS_PER_SEC;
        tstrassen[count] = elapsed;
        printf("strassen = %5.4f\n",tstrassen[count]);
    }
    printf("-------------------------------------------------------------------\n\n\n");

    // Advance idx while the naive timing at the next table entry is not worse
    // than the Strassen timing. strassenMul reads idx as its "m <= idx"
    // fallback threshold.
    // NOTE(review): tnaive[idx+1] is evaluated before "idx < 14" is tested, so
    // with idx == 14 element 15 of the 15-element timing arrays is read.
    while (tnaive[idx+1] <= tstrassen[idx+1] && idx < 14) idx++;

    printf("Optimum input size to switch from normal multiplication to Strassen's is above %d\n\n", idx);

    printf("Please enter the size of array as a power of 2\n");
    // NOTE(review): the scanf result is not checked and N is used unvalidated
    // as the dimension of three more variable-length arrays.
    scanf("%d",&N);
    double A[N][N], B[N][N], C[N][N];
    srand(time(NULL));
    for (i = 0; i < N; i++)
    {
        for (j = 0; j < N; j++)
        {
            A[i][j] = rand()/(RAND_MAX + 1.);
            B[i][j] = rand()/(RAND_MAX + 1.);
        }
    }

printf("------------------- Input Matrices A and B ---------------------------\n\n");
    // Print A, then B, one matrix row per output line
    for (i = 0; i < N; i++)
    {
        for (j = 0; j < N; j++)
            printf("%5.4f  ",A[i][j]);
        printf("\n");
    }
    printf("\n");

    for (i = 0; i < N; i++)
    {
        for (j = 0; j < N; j++)
            printf("%5.4f  ",B[i][j]);
        printf("\n");
    }
    printf("\n------- Output matrix by Strassen's method after optimization -----------\n\n");

    strassenMul((double *)A, (double *)B, (double *)C, N);

    for (i = 0; i < N; i++)
    {
        for (j = 0; j < N; j++)
            printf("%5.4f  ",C[i][j]);
        printf("\n");
    }
    return(0);
}

/*
 * Multiply the m x m row-major matrices X and Y into Z by Strassen's
 * recursion: split each operand into four n x n quadrants (n = m/2), form
 * seven recursive products P1..P7 and combine them into the quadrants of Z.
 *
 * X, Y : inputs, m*m doubles each, addressed as X[row*m + col]
 * Z    : output, m*m doubles
 * m    : matrix dimension; the quadrant split uses n = m/2
 *
 * Falls back to matMul when m <= idx (global threshold) and to a single
 * scalar multiply when m == 1.
 */
void strassenMul(double *X, double *Y, double *Z, int m)
{
    if (m <= idx)
    {
        matMul((double *)X, (double *)Y, (double *)Z, m);
        return;
    }
    if (m == 1)
    {
        *Z = *X * *Y;
        return;
    }
    int row = 0, col = 0;
    int n = m/2;
    int i = 0, j = 0;
    // NOTE(review): 33 n x n variable-length arrays in automatic storage per
    // call. They stay live while the recursive calls below run, so every
    // level of the recursion adds its own set on top.
    double x11[n][n], x12[n][n], x21[n][n], x22[n][n];
    double y11[n][n], y12[n][n], y21[n][n], y22[n][n];
    double P1[n][n], P2[n][n], P3[n][n], P4[n][n], P5[n][n], P6[n][n], P7[n][n];
    double C11[n][n], C12[n][n], C21[n][n], C22[n][n];
    double S1[n][n], S2[n][n], S3[n][n], S4[n][n], S5[n][n], S6[n][n], S7[n][n];
    double S8[n][n], S9[n][n], S10[n][n], S11[n][n], S12[n][n], S13[n][n], S14[n][n];

    // Copy the top half of X and Y into the 11 (left) and 12 (right) quadrants
    for (row = 0, i = 0; row < n; row++, i++)
    {
        for (col = 0, j = 0; col < n; col++, j++)
        {
            x11[i][j] = *((X+row*m)+col);
            y11[i][j] = *((Y+row*m)+col);
        }
        for (col = n, j = 0; col < m; col++, j++)
        {
            x12[i][j] = *((X+row*m)+col);
            y12[i][j] = *((Y+row*m)+col);
        }
    }

    // Copy the bottom half of X and Y into the 21 (left) and 22 (right) quadrants
    for (row = n, i = 0; row < m; row++, i++)
    {
        for (col = 0, j = 0; col < n; col++, j++)
        {
            x21[i][j] = *((X+row*m)+col);
            y21[i][j] = *((Y+row*m)+col);
        }
        for (col = n, j = 0; col < m; col++, j++)
        {
            x22[i][j] = *((X+row*m)+col);
            y22[i][j] = *((Y+row*m)+col);
        }
    }

    // Calculating P1 = (x11 + x22) * (y11 + y22)
    matAdd((double *)x11, (double *)x22, (double *)S1, n);
    matAdd((double *)y11, (double *)y22, (double *)S2, n);
    strassenMul((double *)S1, (double *)S2, (double *)P1, n);

    // Calculating P2 = (x21 + x22) * y11
    matAdd((double *)x21, (double *)x22, (double *)S3, n);
    strassenMul((double *)S3, (double *)y11, (double *)P2, n);

    // Calculating P3 = x11 * (y12 - y22)
    matSub((double *)y12, (double *)y22, (double *)S4, n);
    strassenMul((double *)x11, (double *)S4, (double *)P3, n);

    // Calculating P4 = x22 * (y21 - y11)
    matSub((double *)y21, (double *)y11, (double *)S5, n);
    strassenMul((double *)x22, (double *)S5, (double *)P4, n);

    // Calculating P5 = (x11 + x12) * y22
    matAdd((double *)x11, (double *)x12, (double *)S6, n);
    strassenMul((double *)S6, (double *)y22, (double *)P5, n);

    // Calculating P6 = (x21 - x11) * (y11 + y12)
    matSub((double *)x21, (double *)x11, (double *)S7, n);
    matAdd((double *)y11, (double *)y12, (double *)S8, n);
    strassenMul((double *)S7, (double *)S8, (double *)P6, n);

    // Calculating P7 = (x12 - x22) * (y21 + y22)
    matSub((double *)x12, (double *)x22, (double *)S9, n);
    matAdd((double *)y21, (double *)y22, (double *)S10, n);
    strassenMul((double *)S9, (double *)S10, (double *)P7, n);

    // Calculating C11 = P1 + P4 - P5 + P7
    matAdd((double *)P1, (double *)P4, (double *)S11, n);
    matSub((double *)S11, (double *)P5, (double *)S12, n);
    matAdd((double *)S12, (double *)P7, (double *)C11, n);

    // Calculating C12 = P3 + P5
    matAdd((double *)P3, (double *)P5, (double *)C12, n);

    // Calculating C21 = P2 + P4
    matAdd((double *)P2, (double *)P4, (double *)C21, n);

    // Calculating C22 = P1 + P3 - P2 + P6
    matAdd((double *)P1, (double *)P3, (double *)S13, n);
    matSub((double *)S13, (double *)P2, (double *)S14, n);
    matAdd((double *)S14, (double *)P6, (double *)C22, n);

    // Write the four result quadrants back into Z: top half first ...
    for (row = 0, i = 0; row < n; row++, i++)
    {
        for (col = 0, j = 0; col < n; col++, j++)
            *((Z+row*m)+col) = C11[i][j];
        for (col = n, j = 0; col < m; col++, j++)
            *((Z+row*m)+col) = C12[i][j];
    }
    // ... then the bottom half
    for (row = n, i = 0; row < m; row++, i++)
    {
        for (col = 0, j = 0; col < n; col++, j++)
            *((Z+row*m)+col) = C21[i][j];
        for (col = n, j = 0; col < m; col++, j++)
            *((Z+row*m)+col) = C22[i][j];
    }
}

/*
 * Naive triple-loop product of two n x n row-major matrices: C = A * B.
 *
 * A, B : inputs, n*n doubles each
 * C    : output, n*n doubles
 * n    : matrix dimension; nothing is written when n <= 0
 */
void matMul(double *A, double *B, double *C, int n)
{
    for (int r = 0; r < n; r++)
    {
        double *a_row = A + r*n;    // row r of A
        double *c_row = C + r*n;    // row r of C

        for (int c = 0; c < n; c++)
        {
            // dot product of row r of A with column c of B
            double acc = 0.0;
            for (int t = 0; t < n; t++)
                acc += a_row[t] * B[t*n + c];
            c_row[c] = acc;
        }
    }
}

/*
 * Element-wise sum of two m x m row-major matrices: C = A + B.
 * Nothing is written when m <= 0.
 */
void matAdd(double *A, double *B, double *C, int m)
{
    int r, c;

    for (r = 0; r < m; r++)
    {
        int base = r*m;             // offset of row r
        for (c = 0; c < m; c++)
            C[base + c] = A[base + c] + B[base + c];
    }
}

/*
 * Element-wise difference of two m x m row-major matrices: C = A - B.
 * Nothing is written when m <= 0.
 */
void matSub(double *A, double *B, double *C, int m)
{
    int r, c;

    for (r = 0; r < m; r++)
    {
        int base = r*m;             // offset of row r
        for (c = 0; c < m; c++)
            C[base + c] = A[base + c] - B[base + c];
    }
}

Added later: If I use malloc for the memory allocation instead, the code is as follows. The problem is that the program stops after the naive matrix multiplication and does not even proceed to Strassen's method for N=1; it shows a prompt to close the program.

// Benchmark loop, heap variant: each matrix is a table of N row pointers
// (double **) with every row allocated separately.
for (count = 0; count < total; count++)
{
    N = pow(2, count);
    printf("Matrix size = %2d\t",N);
    //double X[N][N], Y[N][N], Z[N][N], W[N][N];
    double **X, **Y, **Z, **W;
    // Row-pointer tables: N pointers per matrix
    X = malloc(N * sizeof(double*));
    if (X == NULL){
        perror("Failed malloc() in X");
        return 1;
    }
    Y = malloc(N * sizeof(double*));
            if (Y == NULL){
                perror("Failed malloc() in Y");
                return 1;
    }
    Z = malloc(N * sizeof(double*));
            if (Z == NULL){
                perror("Failed malloc() in Z");
                return 1;
    }
    W = malloc(N * sizeof(double*));
            if (W == NULL){
                perror("Failed malloc() in W");
                return 1;
    }
    // One allocation per row.
    // NOTE(review): the rows hold double values but are sized with
    // sizeof(double*); sizeof(double) looks like what was intended.
    for (j = 0; j < N; j++)
    {
        X[j] = malloc(N * sizeof(double*));
        if (X[j] == NULL){
            perror("Failed malloc() in X[j]");
            return 1;
        }
        Y[j] = malloc(N * sizeof(double*));
                    if (Y[j] == NULL){
                        perror("Failed malloc() in Y[j]");
                        return 1;
        }
        Z[j] = malloc(N * sizeof(double*));
                    if (Z[j] == NULL){
                        perror("Failed malloc() in Z[j]");
                        return 1;
        }
        W[j] = malloc(N * sizeof(double*));
                    if (W[j] == NULL){
                        perror("Failed malloc() in W[j]");
                        return 1;
        }
    }
    srand(time(NULL));
    for (i = 0; i < N; i++)
    {
        for (j = 0; j < N; j++)
        {
            X[i][j] = rand()/(RAND_MAX + 1.);
            Y[i][j] = rand()/(RAND_MAX + 1.);
        }
    }
    start = clock();
    // NOTE(review): X, Y and W are tables of row pointers, but they are cast
    // to double* here. matMul, as listed earlier on this page, indexes its
    // arguments as one contiguous block of n*n doubles, so it would operate
    // on the pointer tables rather than on the rows. This is presumably why
    // the run fails after this point; confirm.
    matMul((double *)X, (double *)Y, (double *)W, N);
    end = clock();
    elapsed = ((double) (end - start))*100/ CLOCKS_PER_SEC;
    tnaive[count] = elapsed;
    printf("naive = %5.4f\t\t",tnaive[count]);

    start = clock();
    // Same double** -> double* cast as for matMul above
    strassenMul((double *)X, (double *)Y, (double *)Z, N);
    end = clock();
    elapsed = ((double) (end - start))*100/ CLOCKS_PER_SEC;
    tstrassen[count] = elapsed;
    // Release the rows first, then the row-pointer tables
    for (j = 0; j < N; j++)
    {
        free(X[j]);
        free(Y[j]);
        free(Z[j]);
        free(W[j]);
    }

    free(X); free(Y); free(Z); free(W);

    printf("strassen = %5.4f\n",tstrassen[count]);
}

I have re-written the answer. My previous answer which allocated memory row by row won't work, because OP has cast the 2-D arrays to 1-D arrays when passed to the functions. Here is my re-write of the code with some simplifications, such as keeping all the matrix arrays 1-dimensional.

I am unsure exactly what Strassen's method does, although the recursion halves the matrix dimensions. So I do wonder if the intention was to use row*2 and col*2 when accessing the arrays passed.

I hope the techniques are useful to you - and that the code even works! All the matrix arrays are now on the heap.

#include<stdio.h>
#include<stdlib.h>
#include<math.h>
#include<time.h>

#define total   4       //15

void strassenMul(double* X, double* Y, double* Z, int m);
void matMul(double* A, double* B, double* C, int n);
void matAdd(double* A, double* B, double* C, int m);
void matSub(double* A, double* B, double* C, int m);

// Slot numbers into the arr[] pointer table that strassenMul builds for its
// n x n work matrices:
//   x11..x22, y11..y22  quadrants copied out of the X and Y operands
//   P1..P7              the seven recursive products
//   C11..C22            quadrants of the result, copied into Z at the end
//   S1..S14             intermediate sums and differences
// arrs comes last, so its value is the number of slots.
enum array { x11, x12, x21, x22, y11, y12, y21, y22,
    P1, P2, P3, P4, P5, P6, P7, C11, C12, C21, C22,
    S1, S2, S3, S4, S5, S6, S7, S8, S9, S10, S11, S12, S13, S14, arrs };

int idx = 0;

/*
 * Benchmark driver.
 *
 * 1. For count = 0 .. total-1, fills two random N x N matrices (N = 2^count)
 *    and records how long matMul and strassenMul take on them.
 * 2. Derives the crossover size from those timings and stores it in the
 *    global idx, which strassenMul compares against the matrix dimension.
 * 3. Reads a power-of-two size from stdin, multiplies two random matrices of
 *    that size with strassenMul and prints the inputs and the product.
 *
 * Returns 0 on success, 1 on allocation failure or invalid input.
 */
int main()
{
    int N;
    int count = 0;
    int i, j;
    int crossover = 0;              // index into the timing tables
    size_t cells, k;                // element count of one matrix / flat index
    clock_t start, end;
    double elapsed;

    double tnaive[total];
    double tstrassen[total];
    double *X, *Y, *Z, *W, *A, *B, *C;

    // Seed once. Reseeding with time(NULL) before every fill restarts the
    // same sequence whenever time() has not advanced in between.
    srand((unsigned)time(NULL));

    printf("-------------------------------------------------------------------------\n\n");
    for (count = 0; count < total; count++)
    {
        N = 1 << count;             // exact power of two, no floating point
        cells = (size_t)N * (size_t)N;
        printf("Matrix size = %2d\t",N);
        X = malloc(cells * sizeof *X);
        Y = malloc(cells * sizeof *Y);
        Z = malloc(cells * sizeof *Z);
        W = malloc(cells * sizeof *W);
        if (X==NULL || Y==NULL || Z==NULL || W==NULL) {
            printf("Out of memory (1)\n");
            free(W);                // free(NULL) is a no-op
            free(Z);
            free(Y);
            free(X);
            return 1;
        }
        for (k=0; k<cells; k++)
        {
            // rand()/(RAND_MAX + 1.) yields a value in [0, 1)
            X[k] = rand()/(RAND_MAX + 1.);
            Y[k] = rand()/(RAND_MAX + 1.);
        }
        start = clock();
        matMul(X, Y, W, N);
        end = clock();
        // clock ticks * 100 / CLOCKS_PER_SEC, i.e. hundredths of a second
        elapsed = ((double) (end - start))*100/ CLOCKS_PER_SEC;
        tnaive[count] = elapsed;
        printf("naive = %5.4f\t\t",tnaive[count]);

        // idx is still 0 here, so strassenMul does not take its
        // "m <= idx" fallback to matMul.
        start = clock();
        strassenMul(X, Y, Z, N);
        end = clock();
        elapsed = ((double) (end - start))*100/ CLOCKS_PER_SEC;
        tstrassen[count] = elapsed;
        printf("strassen = %5.4f\n",tstrassen[count]);

        // Release the buffers outside the timed region
        free(W);
        free(Z);
        free(Y);
        free(X);
    }
    printf("-------------------------------------------------------------------\n\n\n");

    // Walk up from size 2 while the naive method is still at least as fast.
    // The bound is tested first and is tied to total, so the timing tables
    // are never read past their last element.
    while (crossover + 1 < total && tnaive[crossover+1] <= tstrassen[crossover+1])
        crossover++;

    // strassenMul compares idx with the matrix dimension and the message
    // below reports an input size, so store 2^crossover, not the table index.
    idx = 1 << crossover;

    printf("Optimum input size to switch from normal multiplication to Strassen's is above %d\n\n", idx);

    printf("Please enter the size of array as a power of 2\n");
    if (scanf("%d",&N) != 1 || N <= 0 || (N & (N - 1)) != 0) {
        printf("Size must be a positive power of 2\n");
        return 1;
    }
    // Refuse sizes whose byte count does not fit in size_t
    if ((size_t)N > ((size_t)-1 / sizeof *A) / (size_t)N) {
        printf("Out of memory (2)\n");
        return 1;
    }
    cells = (size_t)N * (size_t)N;
    A = malloc(cells * sizeof *A);
    B = malloc(cells * sizeof *B);
    C = malloc(cells * sizeof *C);
    if (A==NULL || B==NULL || C==NULL) {
        printf("Out of memory (2)\n");
        free(C);
        free(B);
        free(A);
        return 1;
    }
    for (k=0; k<cells; k++)
    {
        A[k] = rand()/(RAND_MAX + 1.);
        B[k] = rand()/(RAND_MAX + 1.);
    }

    printf("------------------- Input Matrices A and B ---------------------------\n\n");
    for (i = 0; i < N; i++)
    {
        for (j = 0; j < N; j++)
            printf("%5.4f  ",A[(size_t)i*N+j]);
        printf("\n");
    }
    printf("\n");

    for (i = 0; i < N; i++)
    {
        for (j = 0; j < N; j++)
            printf("%5.4f  ",B[(size_t)i*N+j]);
        printf("\n");
    }
    printf("\n------- Output matrix by Strassen's method after optimization -----------\n\n");

    strassenMul(A, B, C, N);

    for (i = 0; i < N; i++)
    {
        for (j = 0; j < N; j++)
            printf("%5.4f  ",C[(size_t)i*N+j]);
        printf("\n");
    }
    free(C);
    free(B);
    free(A);
    return(0);
}

/*
 * Multiply the m x m row-major matrices X and Y into Z by Strassen's
 * recursion: split each operand into four n x n quadrants (n = m/2), form
 * seven recursive products P1..P7 and combine them into the quadrants of Z.
 *
 * X, Y : inputs, m*m doubles each, addressed as X[row*m + col]
 * Z    : output, m*m doubles
 * m    : matrix dimension
 *
 * Base cases, tested in this order:
 *   m <= idx  -> matMul (idx is the global crossover threshold)
 *   m == 1    -> single scalar multiply
 *   m odd     -> matMul, because an odd matrix cannot be cut into four equal
 *                quadrants and the copy loops would overrun the n x n buffers
 *
 * All work matrices live in one heap block that is released before
 * returning. If that block cannot be allocated the process prints a message
 * and exits with status 1.
 */
void strassenMul(double *X, double *Y, double *Z, int m)
{
    int n = m/2;
    int i = 0, j = 0;
    size_t cells;                           // elements per work matrix
    double *pool = NULL;                    // backing store for every slot
    double *arr[arrs];                      // each matrix mem ptr

    if (m <= idx)
    {
        matMul(X, Y, Z, m);
        return;
    }
    if (m == 1)
    {
        *Z = *X * *Y;
        return;
    }
    if (m % 2 != 0)
    {
        matMul(X, Y, Z, m);
        return;
    }

    // One allocation for all work matrices; sizes are computed in size_t and
    // checked so the byte count cannot wrap.
    cells = (size_t)n * (size_t)n;
    if (cells <= ((size_t)-1 / sizeof *pool) / arrs)
        pool = malloc(arrs * cells * sizeof *pool);
    if (pool == NULL) {
        printf("Out of memory (1)\n");
        exit (1);                           // brutal
    }
    for (i=0; i<arrs; i++)
        arr[i] = pool + (size_t)i * cells;

    // Split X and Y into their four quadrants
    for (i = 0; i < n; i++)
    {
        for (j = 0; j < n; j++)
        {
            size_t q  = (size_t)i * n + j;  // cell (i, j) inside a quadrant
            size_t tl = (size_t)i * m + j;  // same cell in the top-left quadrant of the m x m matrix
            size_t tr = tl + n;             // ... top-right
            size_t bl = tl + (size_t)n * m; // ... bottom-left
            size_t br = bl + n;             // ... bottom-right

            arr[x11][q] = X[tl];
            arr[x12][q] = X[tr];
            arr[x21][q] = X[bl];
            arr[x22][q] = X[br];

            arr[y11][q] = Y[tl];
            arr[y12][q] = Y[tr];
            arr[y21][q] = Y[bl];
            arr[y22][q] = Y[br];
        }
    }

    // Calculating P1 = (x11 + x22) * (y11 + y22)
    matAdd(arr[x11], arr[x22], arr[S1], n);
    matAdd(arr[y11], arr[y22], arr[S2], n);
    strassenMul(arr[S1], arr[S2], arr[P1], n);

    // Calculating P2 = (x21 + x22) * y11
    matAdd(arr[x21], arr[x22], arr[S3], n);
    strassenMul(arr[S3], arr[y11], arr[P2], n);

    // Calculating P3 = x11 * (y12 - y22)
    matSub(arr[y12], arr[y22], arr[S4], n);
    strassenMul(arr[x11], arr[S4], arr[P3], n);

    // Calculating P4 = x22 * (y21 - y11)
    matSub(arr[y21], arr[y11], arr[S5], n);
    strassenMul(arr[x22], arr[S5], arr[P4], n);

    // Calculating P5 = (x11 + x12) * y22
    matAdd(arr[x11], arr[x12], arr[S6], n);
    strassenMul(arr[S6], arr[y22], arr[P5], n);

    // Calculating P6 = (x21 - x11) * (y11 + y12)
    matSub(arr[x21], arr[x11], arr[S7], n);
    matAdd(arr[y11], arr[y12], arr[S8], n);
    strassenMul(arr[S7], arr[S8], arr[P6], n);

    // Calculating P7 = (x12 - x22) * (y21 + y22)
    matSub(arr[x12], arr[x22], arr[S9], n);
    matAdd(arr[y21], arr[y22], arr[S10], n);
    strassenMul(arr[S9], arr[S10], arr[P7], n);

    // Calculating C11 = P1 + P4 - P5 + P7
    matAdd(arr[P1], arr[P4], arr[S11], n);
    matSub(arr[S11], arr[P5], arr[S12], n);
    matAdd(arr[S12], arr[P7], arr[C11], n);

    // Calculating C12 = P3 + P5
    matAdd(arr[P3], arr[P5], arr[C12], n);

    // Calculating C21 = P2 + P4
    matAdd(arr[P2], arr[P4], arr[C21], n);

    // Calculating C22 = P1 + P3 - P2 + P6
    matAdd(arr[P1], arr[P3], arr[S13], n);
    matSub(arr[S13], arr[P2], arr[S14], n);
    matAdd(arr[S14], arr[P6], arr[C22], n);

    // Reassemble the four result quadrants into Z
    for (i = 0; i < n; i++)
    {
        for (j = 0; j < n; j++)
        {
            size_t q  = (size_t)i * n + j;
            size_t tl = (size_t)i * m + j;
            size_t tr = tl + n;
            size_t bl = tl + (size_t)n * m;
            size_t br = bl + n;

            Z[tl] = arr[C11][q];
            Z[tr] = arr[C12][q];
            Z[bl] = arr[C21][q];
            Z[br] = arr[C22][q];
        }
    }

    free (pool);
}

/*
 * Naive triple-loop product of two n x n row-major matrices: C = A * B.
 *
 * A, B : inputs, n*n doubles each
 * C    : output, n*n doubles
 * n    : matrix dimension; nothing is written when n <= 0
 */
void matMul(double *A, double *B, double *C, int n)
{
    int r, c, t;

    for (r = 0; r < n; ++r) {
        for (c = 0; c < n; ++c) {
            const double *lhs = &A[r*n];    // row r of A
            const double *rhs = &B[c];      // top of column c of B
            double dot = 0.0;

            for (t = 0; t < n; ++t)
                dot += lhs[t] * rhs[t*n];
            C[r*n + c] = dot;
        }
    }
}

/*
 * Element-wise sum of two m x m row-major matrices: C = A + B.
 * Nothing is written when m <= 0.
 */
void matAdd(double *A, double *B, double *C, int m)
{
    for (int r = 0; r < m; r++) {
        const double *lhs = A + r*m;
        const double *rhs = B + r*m;
        double *out = C + r*m;

        for (int c = 0; c < m; c++)
            out[c] = lhs[c] + rhs[c];
    }
}

/*
 * Element-wise difference of two m x m row-major matrices: C = A - B.
 * Nothing is written when m <= 0.
 */
void matSub(double *A, double *B, double *C, int m)
{
    for (int r = 0; r < m; r++) {
        const double *lhs = A + r*m;
        const double *rhs = B + r*m;
        double *out = C + r*m;

        for (int c = 0; c < m; c++)
            out[c] = lhs[c] - rhs[c];
    }
}

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM