multiplication using SSE (x*x*x)+(y*y*y)

I'm trying to optimize this function using SIMD but I don't know where to start.

long cubeSum(int x,int y)
{
    return x*x*x+y*y*y;
}

The disassembled function looks like this:

  4007a0:   48 89 f2                mov    %rsi,%rdx
  4007a3:   48 89 f8                mov    %rdi,%rax
  4007a6:   48 0f af d6             imul   %rsi,%rdx
  4007aa:   48 0f af c7             imul   %rdi,%rax
  4007ae:   48 0f af d6             imul   %rsi,%rdx
  4007b2:   48 0f af c7             imul   %rdi,%rax
  4007b6:   48 8d 04 02             lea    (%rdx,%rax,1),%rax
  4007ba:   c3                      retq   
  4007bb:   0f 1f 44 00 00          nopl   0x0(%rax,%rax,1)

The calling code looks like this:

do {
    for (i = 0; i < maxi; i++) {
        j = nextj[i];
        long sum = cubeSum(i,j);
        while (sum <= p) {
            long x = sum & (psize - 1);
            int flag = table[x];
            if (flag <= guard) {
                table[x] = guard+1;
            } else if (flag == guard+1) {
                table[x] = guard+2;
                count++;
            }
            j++;
            sum = cubeSum(i,j);
        }
        nextj[i] = j;
    }
    p += psize;
    guard += 3;
} while (p <= n);

  • Fill one SSE register with (x|y|0|0) (each SSE register holds four 32-bit elements). Let's call it r1.
  • Copy that register into another register, r2.
  • Multiply r2 by r1, storing the result in r2 (a packed 32-bit multiply, pmulld, requires SSE4.1).
  • Multiply r2 by r1 again, storing the result in r2.
  • r2 now holds (x*x*x|y*y*y|0|0).
  • Unpack the lower two elements of r2 into separate registers and add them (SSE3 has horizontal add instructions, but only for floats and doubles; the integer versions arrived with SSSE3).
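
The steps above can be sketched with compiler intrinsics. This is only an illustration under some assumptions: the name cube_sum_sse is made up, the target attribute is a GCC/Clang extension used here so that no -msse4.1 flag is needed, and the inputs are assumed small enough that each cube fits in a 32-bit lane.

```c
#include <assert.h>

#if defined(__x86_64__) || defined(__i386__)
#include <smmintrin.h>  /* SSE4.1: _mm_mullo_epi32, _mm_extract_epi32 */

/* Hypothetical name. The target attribute (GCC/Clang) enables SSE4.1
   for this one function only. */
static __attribute__((target("sse4.1"))) long cube_sum_sse(int x, int y)
{
    /* A 32-bit lane overflows once |v|^3 exceeds INT_MAX, i.e. |v| > 1290. */
    assert(x >= -1290 && x <= 1290 && y >= -1290 && y <= 1290);

    __m128i r1 = _mm_set_epi32(0, 0, y, x);  /* lanes: (x | y | 0 | 0) */
    __m128i r2 = r1;                         /* copy of r1 */
    r2 = _mm_mullo_epi32(r2, r1);            /* (x*x | y*y | 0 | 0) */
    r2 = _mm_mullo_epi32(r2, r1);            /* (x*x*x | y*y*y | 0 | 0) */

    /* Unpack the two low lanes and add them as scalars. */
    long x3 = _mm_cvtsi128_si32(r2);
    long y3 = _mm_extract_epi32(r2, 1);
    return x3 + y3;
}
#else
/* Fallback so the sketch still compiles on non-x86 targets: scalar code. */
static long cube_sum_sse(int x, int y)
{
    return (long)x * x * x + (long)y * y * y;
}
#endif
```

Calling cube_sum_sse(2, 3) computes 8 + 27. Note that half of the register is wasted on zeros and the result still has to be moved back to scalar registers, which is why this is unlikely to beat the four imul instructions shown in the disassembly.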

In the end, I'd be surprised if this turned out to be any faster than the simple code the compiler has already generated for you. SIMD is more useful when you have arrays of data to operate on.

This particular case is not a good fit for SIMD (SSE or otherwise). SIMD really only works well when you have contiguous arrays that you can access sequentially and process homogeneously, i.e. with the same operation applied to every element.
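
For contrast, here is a minimal sketch of the kind of loop that does suit SIMD: one contiguous array, read sequentially, with the same operation on every element. The function name cube_all is made up for illustration, and whether a given compiler actually vectorizes the loop depends on the compiler, target and optimization level.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: cube every element of a contiguous array.
   Every iteration does the same work on adjacent elements, so a
   compiler may process several elements per instruction.
   Assumes |in[k]| <= 1290 so that each cube fits in 32 bits. */
static void cube_all(const int32_t *in, int32_t *out, size_t n)
{
    for (size_t k = 0; k < n; k++) {
        int32_t v = in[k];
        out[k] = v * v * v;
    }
}
```

The loop in the question is different: each inner iteration depends on a table lookup and a data-dependent exit test, so there is no long run of independent, identical operations to batch.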

However, you can at least get rid of some of the redundant operations in the scalar code, e.g. repeatedly calculating i * i * i, which is invariant inside the inner while loop:

do {
    for (i = 0; i < maxi; i++) {
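        long i3 = (long)i * i * i;   /* hoisted: i is fixed for the whole while loop */
        int j = nextj[i];
        long j3 = (long)j * j * j;   /* computed in 64 bits, like the imul in the disassembly */
        long sum = i3 + j3;
        while (sum <= p) {
            long x = sum & (psize - 1);
            int flag = table[x];
            if (flag <= guard) {
                table[x] = guard+1;
            } else if (flag == guard+1) {
                table[x] = guard+2;
                count++;
            }
            j++;
            j3 = (long)j * j * j;    /* only the j term is recomputed */
            sum = i3 + j3;
        }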
        }
        nextj[i] = j;
    }
    p += psize;
    guard += 3;
} while (p <= n);
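
As a further step that is not part of the answer above, the remaining multiplications for j * j * j could in principle be replaced by additions, because consecutive cubes differ by (j+1)^3 - j^3 = 3j^2 + 3j + 1 and that difference itself grows by 6(j+1) each step. The sketch below assumes this finite-difference approach; the type and function names are made up for illustration, and whether it is actually faster than two imul instructions would need measuring.

```c
#include <stdint.h>

/* Hypothetical helper state for stepping through consecutive cubes. */
typedef struct {
    int64_t j;      /* current value of j */
    int64_t cube;   /* j^3 */
    int64_t delta;  /* (j+1)^3 - j^3 = 3j^2 + 3j + 1 */
    int64_t delta2; /* delta(j+1) - delta(j) = 6j + 6 */
} cube_iter;

/* Set up the state for a starting j (multiplications happen only here). */
static cube_iter cube_iter_start(int64_t j)
{
    cube_iter it;
    it.j = j;
    it.cube = j * j * j;
    it.delta = 3 * j * j + 3 * j + 1;
    it.delta2 = 6 * j + 6;
    return it;
}

/* Advance from j to j+1 using additions only. */
static void cube_iter_next(cube_iter *it)
{
    it->cube += it->delta;    /* now (j+1)^3 */
    it->delta += it->delta2;  /* now 3(j+1)^2 + 3(j+1) + 1 */
    it->delta2 += 6;          /* now 6(j+1) + 6 */
    it->j += 1;
}
```

In the loop above, j3 would be initialised from cube_iter_start(nextj[i]) and the j++ / j3 = ... pair replaced by one call to cube_iter_next. The full state, not just j, would have to be kept between passes of the outer do-while loop to avoid redoing the setup multiplications.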
