
Using inline assembly to speed up matrix multiplication

I have been trying to speed up matrix-matrix multiplication C <- C + alpha * A * B via register blocking, SSE2 vectorization and L1 cache blocking (note that I have deliberately chosen the transpose setting op(A)=A and op(B)=B). After some effort, the code I have written is still about 50% slower than GotoBLAS in single-threaded mode.

The following is my code for the "kernel" square matrix-matrix multiplication on the L1 cache, called "DGEBB" (general block-block operation) in Goto's work, which multiplies two NB*NB square matrices (NB restricted to be a multiple of 4). I have examined its assembly output under GCC 4.8 and realized that the compiler does not do a good job of scheduling the unrolled innermost loop, the kk-loop. What I hope for is that the compiler allocates registers to attain register reuse, and schedules the computation so that multiplications, additions and memory operations are interleaved for pipelining; however, the compiler fails to do this. For this reason, I would like to replace the innermost loop with some inline assembly.

I am completely new to x86 assembly. Though I have read about GCC's extended asm for hours, I am still not sure how to do it properly. I have attached the best version I could write, knowing that it is wrong. It is modified from the compiler's original assembly output for the kk-loop. Since I know how to allocate registers using "movl", "movapd", etc., I have re-arranged the computation in the order I prefer. But it does not work yet. 1) It seems to me that the registers %eax, %ebx, %ecx are used both inside and outside the assembly, which is nasty. 2) Also, the way I pass the input and output operands does not work. 3) Finally, I really want a version where the whole kk-loop can be inlined. Thanks if someone could help me out!
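For reference, the general operand/clobber shape of GCC extended asm that I have been trying to follow is sketched below (a minimal toy example, not the kernel itself; the "x" constraints let the compiler pick the XMM registers):

#include <emmintrin.h>

/* minimal sketch: one packed-double addition written as extended asm;
   the compiler chooses XMM registers for the "x"-constrained operands */
static inline __m128d add_pd_asm (__m128d x, __m128d y) {
    __asm__ ("addpd %1, %0"   /* x += y */
             : "+x"(x)        /* read-write output, any XMM register */
             : "x"(y));       /* input, any XMM register */
    return x;
}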

The C code for DGEBB (called DGEBB_SSE2_x86, as my laptop is a 32-bit x86 machine with SSE2 - SSE4.1 support):

#include <stdint.h>  /* type define of "uintptr_t" */
#include <emmintrin.h>  /* double precision computation support since SSE2 */
#include <R.h>  /* use R's error handling error() */

void DGEBB_SSE2_x86 (int *NB, double *ALPHA, double *A, double *B, double *C) {
/* check "nb", must be a multiple of 4 */
int TWO=2, FOUR=4, nb=*NB; if (nb%FOUR) error("error in DGEBB_SSE2_x86: nb is not a multiple of 4!\n");
/* check memory alignment of A, B, C, 16 Byte alignment is mandatory (as XMM registers are 128-bit in length) */
uintptr_t sixteen_bytes=0xF;
if ((uintptr_t)A & sixteen_bytes) error("error in DGEBB_SSE2_x86: A is not 16 Bytes aligned in memory!");
if ((uintptr_t)B & sixteen_bytes) error("error in DGEBB_SSE2_x86: B is not 16 Bytes aligned in memory!");
if ((uintptr_t)C & sixteen_bytes) error("error in DGEBB_SSE2_x86: C is not 16 Bytes aligned in memory!");
/* define vector variables */
__m128d C1_vec_reg=_mm_setzero_pd(), C2_vec_reg=C1_vec_reg, C3_vec_reg=C1_vec_reg, C4_vec_reg=C1_vec_reg,A1_vec_reg, A2_vec_reg, B_vec_reg, U_vec_reg;
/* define scalar variables */
int jj, kk, ii, nb2=nb+nb, nb_half=nb/TWO;
double *B1_copy, *B1, *C1, *a, *b, *c, *c0;
/* start triple loop nest */
C1=C;B1=B;  /* initial column tile of C and B */
jj=nb_half;
while (jj--) {
  c=C1;B1_copy=B1;C1+=nb2;B1+=nb2;b=B1_copy;
  for (ii=0; ii<nb; ii+=FOUR) {
    a=A+ii;b=B1_copy;
    kk=nb_half;
    while (kk--) {
    /* [kernel] amortize pointer arithmetic! */
    A1_vec_reg=_mm_load_pd(a);  /* [fetch] */
    B_vec_reg=_mm_load1_pd(b);  /* [fetch] */
    U_vec_reg=_mm_mul_pd(A1_vec_reg,B_vec_reg);C1_vec_reg=_mm_add_pd(C1_vec_reg,U_vec_reg);  /* [daxpy] */
    A2_vec_reg=_mm_load_pd(a+TWO);a+=nb;  /* [fetch] */
    U_vec_reg=_mm_mul_pd(A2_vec_reg,B_vec_reg);C2_vec_reg=_mm_add_pd(C2_vec_reg,U_vec_reg);  /* [daxpy] */
    B_vec_reg=_mm_load1_pd(b+nb);b++;  /* [fetch] */
    U_vec_reg=_mm_mul_pd(A1_vec_reg,B_vec_reg);C3_vec_reg=_mm_add_pd(C3_vec_reg,U_vec_reg);  /* [daxpy] */
    A1_vec_reg=_mm_load_pd(a);  /* [fetch] */
    U_vec_reg=_mm_mul_pd(A2_vec_reg,B_vec_reg);C4_vec_reg=_mm_add_pd(C4_vec_reg,U_vec_reg);  /* [daxpy]*/
    B_vec_reg=_mm_load1_pd(b);  /* [fetch] */
    U_vec_reg=_mm_mul_pd(A1_vec_reg,B_vec_reg);C1_vec_reg=_mm_add_pd(C1_vec_reg,U_vec_reg);  /* [daxpy] */
    A2_vec_reg=_mm_load_pd(a+TWO);a+=nb;  /* [fetch] */
    U_vec_reg=_mm_mul_pd(A2_vec_reg,B_vec_reg);C2_vec_reg=_mm_add_pd(C2_vec_reg,U_vec_reg);  /* [daxpy] */
    B_vec_reg=_mm_load1_pd(b+nb);b++;  /* [fetch] */
    U_vec_reg=_mm_mul_pd(A1_vec_reg,B_vec_reg);C3_vec_reg=_mm_add_pd(C3_vec_reg,U_vec_reg);  /* [daxpy] */
    U_vec_reg=_mm_mul_pd(A2_vec_reg,B_vec_reg);C4_vec_reg=_mm_add_pd(C4_vec_reg,U_vec_reg);  /* [daxpy] */
    }  /* [end of kk-loop] */
  /* [write-back] amortize pointer arithmetic! */
  A2_vec_reg=_mm_load1_pd(ALPHA);
  U_vec_reg=_mm_load_pd(c);c0=c+nb;C1_vec_reg=_mm_mul_pd(C1_vec_reg,A2_vec_reg);  /* [fetch] */
  A1_vec_reg=U_vec_reg;C1_vec_reg=_mm_add_pd(C1_vec_reg,A1_vec_reg);U_vec_reg=_mm_load_pd(c0);  /* [fetch] */
  C3_vec_reg=_mm_mul_pd(C3_vec_reg,A2_vec_reg);_mm_store_pd(c,C1_vec_reg);c+=TWO;  /* [store] */
  A1_vec_reg=U_vec_reg;C3_vec_reg=_mm_add_pd(C3_vec_reg,A1_vec_reg);U_vec_reg=_mm_load_pd(c);  /* [fetch] */
  C2_vec_reg=_mm_mul_pd(C2_vec_reg,A2_vec_reg);_mm_store_pd(c0,C3_vec_reg);c0+=TWO;  /* [store] */
  A1_vec_reg=U_vec_reg;C2_vec_reg=_mm_add_pd(C2_vec_reg,A1_vec_reg);U_vec_reg=_mm_load_pd(c0);  /* [fetch] */
  C4_vec_reg=_mm_mul_pd(C4_vec_reg,A2_vec_reg);_mm_store_pd(c,C2_vec_reg);c+=TWO;  /* [store] */
  C4_vec_reg=_mm_add_pd(C4_vec_reg,U_vec_reg);_mm_store_pd(c0,C4_vec_reg);  /* [store] */
  C1_vec_reg=_mm_setzero_pd();C3_vec_reg=C1_vec_reg;C2_vec_reg=C1_vec_reg;C4_vec_reg=C1_vec_reg;
  }  /* [end of ii-loop] */
}  /* [end of jj-loop] */
}
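For reference, a minimal driver that satisfies the 16-byte alignment requirement could look like the sketch below (the function demo_call and the use of _mm_malloc are just for illustration; any 16-byte aligned allocator works, and the matrices are stored column-major with leading dimension nb):

#include <emmintrin.h>   /* _mm_malloc / _mm_free */

void demo_call (void) {
    int nb = 64;                        /* must be a multiple of 4 */
    double alpha = 1.0;
    /* _mm_malloc guarantees the 16-byte alignment the kernel checks for */
    double *A = _mm_malloc((size_t)nb*nb*sizeof(double), 16);
    double *B = _mm_malloc((size_t)nb*nb*sizeof(double), 16);
    double *C = _mm_malloc((size_t)nb*nb*sizeof(double), 16);
    /* ... fill A, B, C in column-major order with leading dimension nb ... */
    DGEBB_SSE2_x86(&nb, &alpha, A, B, C);   /* C <- C + alpha * A * B */
    _mm_free(A); _mm_free(B); _mm_free(C);
}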

My stupid version of inline assembly for the kk-loop is here:

    while (kk--) {
    asm("movapd %0, %%xmm3\n\t"            /* C1_vec_reg -> xmm3 */
        "movapd %1, %%xmm1\n\t"            /* C2_vec_reg -> xmm1 */
        "movapd %2, %%xmm2\n\t"            /* C3_vec_reg -> xmm2 */
        "movapd %3, %%xmm0\n\t"            /* C4_vec_reg -> xmm0 */
        "movl %4, %%eax\n\t"               /* pointer a -> %eax */
        "movl %5, %%edx\n\t"               /* pointer b -> %edx */
        "movl %6, %%ecx\n\t"               /* block size nb -> %ecx */
        "movapd (%%eax), %%xmm5\n\t"       /* A1_vec_reg -> xmm5 */
        "movsd (%%edx), %%xmm4\n\t"        /* B_vec_reg -> xmm4 */
        "unpcklpd %%xmm4, %%xmm4\n\t"
        "movapd %%xmm5, %%xmm6\n\t"        /* xmm5 -> xmm6 */
        "mulpd %%xmm4, %%xmm6\n\t"         /* xmm6 *= xmm4 */
        "addpd %%xmm6, %%xmm3\n\t"         /* xmm3 += xmm6 */
        "movapd 16(%%eax), %%xmm7\n\t"     /* A2_vec_reg -> xmm7 */
        "movapd %%xmm7, %%xmm6\n\t"        /* xmm7 -> xmm6 */
        "mulpd %%xmm4, %%xmm6\n\t"         /* xmm6 *= xmm4 */
        "addpd %%xmm6, %%xmm1\n\t"         /* xmm1 += xmm6 */
        "movsd (%%edx,%%ecx), %%xmm4\n\t"  /* B_vec_reg -> xmm4 */
        "addl $8, %%edx\n\t"               /* b++ */
        "movsd (%%edx), %%xmm4\n\t"        /* B_vec_reg -> xmm4 */
        "unpcklpd %%xmm4, %%xmm4\n\t"
        "movapd %%xmm5, %%xmm6\n\t"        /* xmm5 -> xmm6 */
        "mulpd %%xmm4, %%xmm6\n\t"         /* xmm6 *= xmm4 */
        "addpd %%xmm6, %%xmm2\n\t"         /* xmm2 += xmm6 */
        "addl %%ecx, %%eax\n\t"            /* a+=nb */
        "movapd (%%eax), %%xmm5\n\t"       /* A1_vec_reg -> xmm5 */
        "movapd %%xmm7, %%xmm6\n\t"        /* xmm7 -> xmm6 */
        "mulpd %%xmm4, %%xmm6\n\t"         /* xmm6 *= xmm4 */
        "addpd %%xmm6, %%xmm0\n\t"         /* xmm0 += xmm6 */
        "movsd (%%edx), %%xmm4\n\t"        /* B_vec_reg -> xmm4 */
        "unpcklpd %%xmm4, %%xmm4\n\t"
        "movapd %%xmm5, %%xmm6\n\t"        /* xmm5 -> xmm6 */
        "mulpd %%xmm4, %%xmm6\n\t"         /* xmm6 *= xmm4 */
        "addpd %%xmm6, %%xmm3\n\t"         /* xmm3 += xmm6 */
        "movapd 16(%%eax), %%xmm7\n\t"     /* A2_vec_reg -> xmm7 */
        "movapd %%xmm7, %%xmm6\n\t"        /* xmm7 -> xmm6 */
        "mulpd %%xmm4, %%xmm6\n\t"         /* xmm6 *= xmm4 */
        "addpd %%xmm6, %%xmm1\n\t"         /* xmm1 += xmm6 */
        "movsd (%%edx,%%ecx), %%xmm4\n\t"  /* B_vec_reg -> xmm4 */
        "addl $8, %%edx\n\t"               /* b++ */
        "movsd (%%edx), %%xmm4\n\t"        /* B_vec_reg -> xmm4 */
        "unpcklpd %%xmm4, %%xmm4\n\t"
        "movapd %%xmm5, %%xmm6\n\t"        /* xmm5 -> xmm6 */
        "mulpd %%xmm4, %%xmm6\n\t"         /* xmm6 *= xmm4 */
        "addpd %%xmm6, %%xmm2\n\t"         /* xmm2 += xmm6 */
        "movapd %%xmm7, %%xmm6\n\t"        /* xmm7 -> xmm6 */
        "mulpd %%xmm4, %%xmm6\n\t"         /* xmm6 *= xmm4 */
        "addpd %%xmm6, %%xmm0\n\t"         /* xmm0 += xmm6 */
        "addl %%ecx, %%eax"
        : "+x"(C1_vec_reg), "+x"(C2_vec_reg), "+x"(C3_vec_reg), "+x"(C4_vec_reg), "+m"(a), "+m"(b)
        : "x"(C1_vec_reg), "x"(C2_vec_reg), "x"(C3_vec_reg), "x"(C4_vec_reg), "4"(a), "5"(b), "rm"(nb));
    }

Here is some explanation of the code:

Unrolling the outer loops exposes a micro "dger" kernel for register reuse:
 (c11 c12) += (a1) * (b1 b2)
 (c21 c22)    (a2)
 (c31 c32)    (a3)
 (c41 c42)    (a4)
This can be implemented as 4 vectorized "daxpy" operations:
 (c11) += (a1) * (b1)  ,  (c31) += (a3) * (b1)  ,  (c12) += (a1) * (b2)  ,  (c32) += (a3) * (b2)  .
 (c21)    (a2)   (b1)     (c41)    (a4)   (b1)     (c22)    (a2)   (b2)     (c42)    (a4)   (b2)
4 micro C-vectors are held permanently in XMM registers named C1_vec_reg, C2_vec_reg, C3_vec_reg, C4_vec_reg.
2 micro A-vectors are loaded into XMM registers named A1_vec_reg, A2_vec_reg.
2 micro B-vectors can reuse a single XMM register named B_vec_reg.
1 additional XMM register, U_vec_reg, will store temporary values.
The above scheduling exploits all 8 XMM registers available on 32-bit x86 architectures with an SIMD unit, and each XMM register is used twice after being loaded.
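In intrinsic form, one such rank-1 update (half of one unrolled kk iteration, using the same variables as the kernel above) condenses to the following sketch:

/* condensed sketch of one 4x2 rank-1 update, with the register naming above */
A1_vec_reg = _mm_load_pd(a);        /* (a1 a2) */
A2_vec_reg = _mm_load_pd(a + 2);    /* (a3 a4) */
B_vec_reg  = _mm_load1_pd(b);       /* (b1 b1) */
C1_vec_reg = _mm_add_pd(C1_vec_reg, _mm_mul_pd(A1_vec_reg, B_vec_reg));  /* (c11 c21) */
C2_vec_reg = _mm_add_pd(C2_vec_reg, _mm_mul_pd(A2_vec_reg, B_vec_reg));  /* (c31 c41) */
B_vec_reg  = _mm_load1_pd(b + nb);  /* (b2 b2) */
C3_vec_reg = _mm_add_pd(C3_vec_reg, _mm_mul_pd(A1_vec_reg, B_vec_reg));  /* (c12 c22) */
C4_vec_reg = _mm_add_pd(C4_vec_reg, _mm_mul_pd(A2_vec_reg, B_vec_reg));  /* (c32 c42) */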

PS: I am an R user from the stats group. The R.h header enables the use of R's error handling function error(), which just terminates the C routine rather than the whole R process. If you do not use R, delete this header and the corresponding lines in the code.
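For non-R users, a drop-in replacement could be as simple as the sketch below (the macro is only illustrative):

#include <stdio.h>
#include <stdlib.h>

/* hypothetical stand-in for R's error(): print the message and quit */
#define error(...) do { fprintf(stderr, __VA_ARGS__); exit(EXIT_FAILURE); } while (0)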

This is an old question dating back to the early phase of development of my HPC Cholesky factorization routine. The C code is outdated, and the assembly is naively incorrect. Later posts follow up on this thread.

(inline assembly in C) Assembler messages: Error: unknown pseudo-op: gives a correct implementation of inline assembly.

How to ask GCC to completely unroll this loop (ie, peel this loop)? gives better C code.

When writing GCC inline assembly, care needs to be paid to potential changes of the status flags. (inline assembly in C) Funny memory segmentation fault was a lesson for me.
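For example, instructions such as addl modify EFLAGS, so the clobber list should mention the condition codes; a minimal sketch (the function add_asm is only illustrative):

/* sketch: an asm block whose instructions alter the flags lists "cc" */
static int add_asm (int x, int y) {
    int sum;
    __asm__ ("addl %2, %0"
             : "=r"(sum)        /* output: any general register */
             : "0"(x), "r"(y)   /* inputs: x shares the output's register */
             : "cc");           /* addl modifies the status flags */
    return sum;
}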

Vectorization is key to HPC. SSE instruction MOVSD (extended: floating point scalar & vector operations on x86, x86-64) contains some discussion of Intel SSE2/3, while FMA instruction _mm256_fmadd_pd(): "132", "231" and "213"? has some information on Intel AVX's FMA instructions.
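For instance, on hardware with FMA support each mulpd/addpd pair in the kernel would collapse into one fused multiply-add; a minimal sketch (compile with -mfma; my 32-bit SSE2 laptop cannot run it):

#include <immintrin.h>

/* sketch: "U = A*B; C = C + U" becomes a single fused multiply-add */
static __m256d fma_step (__m256d A_vec, __m256d B_vec, __m256d C_vec) {
    return _mm256_fmadd_pd(A_vec, B_vec, C_vec);   /* C + A*B */
}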

Of course, all of these are only related to computational kernels. There is a lot of other work related to how everything is wrapped up into a final high-performance Cholesky factorization routine. The performance of the first release of my routine is shown in Why can't my CPU maintain peak performance in HPC.

Currently I am upgrading the kernel routine for even higher performance. Possibly there will be further posts in this thread. Thanks to the Stack Overflow community, especially Z boson, Peter Cordes and nominal animal, for answering my various questions. I learnt a lot and am really happy with this process. [At the same time, I surely learnt to be a better SO member.]
