繁体   English   中英

如何使用汇编优化这个 8 位位置 popcount?

[英]How to optimise this 8-bit positional popcount using assembly?

这篇文章与_mm_add_epi32 的 Golang 汇编实现有关,它在两个[8]int32列表中添加成对的元素,并返回更新后的第一个。

根据 pprof 配置文件,我发现传递[8]int32是昂贵的,所以我认为传递列表的指针要便宜得多并且 bech 结果验证了这一点。 这是 go 版本:

func __mm_add_epi32_inplace_purego(x, y *[8]int32) {
    (*x)[0] += (*y)[0]
    (*x)[1] += (*y)[1]
    (*x)[2] += (*y)[2]
    (*x)[3] += (*y)[3]
    (*x)[4] += (*y)[4]
    (*x)[5] += (*y)[5]
    (*x)[6] += (*y)[6]
    (*x)[7] += (*y)[7]
}

这个 function 在两级循环中被调用。

该算法计算一个字节数组的position 人口计数

感谢@fuz 的建议,我知道用汇编编写整个算法是最好的选择并且很有意义,但这超出了我的能力,因为我从未学习过汇编编程。

但是,使用汇编优化内循环应该很容易:

counts := make([][8]int32, numRowBytes)

for i, b = range byteSlice {
    if b == 0 {                  // more than half of elements in byteSlice is 0.
        continue
    }
    expand = _expand_byte[b]
    __mm_add_epi32_inplace_purego(&counts[i], expand)
}

// expands a byte into its bits
var _expand_byte = [256]*[8]int32{
    &[8]int32{0, 0, 0, 0, 0, 0, 0, 0},
    &[8]int32{0, 0, 0, 0, 0, 0, 0, 1},
    &[8]int32{0, 0, 0, 0, 0, 0, 1, 0},
    &[8]int32{0, 0, 0, 0, 0, 0, 1, 1},
    &[8]int32{0, 0, 0, 0, 0, 1, 0, 0},
    ...
}

你能帮忙写一个__mm_add_epi32_inplace_purego的汇编版本(这对我来说已经足够了),甚至是整个循环吗? 先感谢您。

您要执行的操作称为字节位置填充计数 这是机器学习中使用的一个众所周知的操作,并且已经对快速算法进行了一些研究来解决这个问题。

不幸的是,这些算法的实现相当复杂。 出于这个原因,我开发了一种自定义算法,该算法实现起来要简单得多,但其性能大约只有另一种方法的一半。 但是,在测得的 10 GB/s 下,与您之前的速度相比,它应该仍然是一个不错的改进。

该算法的思想是使用vpmovmskb从 32 字节组中收集相应的位,然后获取标量人口计数,然后将其添加到相应的计数器中。 这允许依赖链很短,并且达到一致的 IPC 3。

请注意,与您的算法相比,我的代码翻转了位的顺序。 如果需要,您可以通过编辑汇编代码访问的数组元素的counts来更改此设置。 然而,为了未来读者的利益,我想保留这段代码更常见的约定,即最低有效位被视为位 0。

源代码

完整的源代码可以在 github找到。 作者同时将这个算法思想发展成一个可移植的库,可以像这样使用:

import "github.com/clausecker/pospop"

var counts [8]int
pospop.Count8(counts, buf)  // add positional popcounts for buf to counts

该算法以两种变体形式提供,并已在处理器标识为“Intel(R) Xeon(R) W-2133 CPU @ 3.60GHz”的机器上进行了测试。

位置填充一次计数 32 个字节。

计数器保存在通用寄存器中以获得最佳性能。 Memory 已提前预取,以获得更好的流式传输行为。 使用非常简单的SHRL / ADCL组合处理标量尾部。 实现了高达 11 GB/s 的性能。

#include "textflag.h"

// func PospopcntReg(counts *[8]int32, buf []byte)
TEXT ·PospopcntReg(SB),NOSPLIT,$0-32
    MOVQ counts+0(FP), DI
    MOVQ buf_base+8(FP), SI     // SI = &buf[0]
    MOVQ buf_len+16(FP), CX     // CX = len(buf)

    // load counts into register R8--R15
    MOVL 4*0(DI), R8
    MOVL 4*1(DI), R9
    MOVL 4*2(DI), R10
    MOVL 4*3(DI), R11
    MOVL 4*4(DI), R12
    MOVL 4*5(DI), R13
    MOVL 4*6(DI), R14
    MOVL 4*7(DI), R15

    SUBQ $32, CX            // pre-subtract 32 bit from CX
    JL scalar

vector: VMOVDQU (SI), Y0        // load 32 bytes from buf
    PREFETCHT0 384(SI)      // prefetch some data
    ADDQ $32, SI            // advance SI past them

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R15            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R14            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R13            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R12            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R11            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R10            // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R9         // add to counter
    VPADDD Y0, Y0, Y0       // shift Y0 left by one place

    VPMOVMSKB Y0, AX        // move MSB of Y0 bytes to AX
    POPCNTL AX, AX          // count population of AX
    ADDL AX, R8         // add to counter

    SUBQ $32, CX
    JGE vector          // repeat as long as bytes are left

scalar: ADDQ $32, CX            // undo last subtraction
    JE done             // if CX=0, there's nothing left

loop:   MOVBLZX (SI), AX        // load a byte from buf
    INCQ SI             // advance past it

    SHRL $1, AX         // CF=LSB, shift byte to the right
    ADCL $0, R8         // add CF to R8

    SHRL $1, AX
    ADCL $0, R9         // add CF to R9

    SHRL $1, AX
    ADCL $0, R10            // add CF to R10

    SHRL $1, AX
    ADCL $0, R11            // add CF to R11

    SHRL $1, AX
    ADCL $0, R12            // add CF to R12

    SHRL $1, AX
    ADCL $0, R13            // add CF to R13

    SHRL $1, AX
    ADCL $0, R14            // add CF to R14

    SHRL $1, AX
    ADCL $0, R15            // add CF to R15

    DECQ CX             // mark this byte as done
    JNE loop            // and proceed if any bytes are left

    // write R8--R15 back to counts
done:   MOVL R8, 4*0(DI)
    MOVL R9, 4*1(DI)
    MOVL R10, 4*2(DI)
    MOVL R11, 4*3(DI)
    MOVL R12, 4*4(DI)
    MOVL R13, 4*5(DI)
    MOVL R14, 4*6(DI)
    MOVL R15, 4*7(DI)

    VZEROUPPER          // restore SSE-compatibility
    RET

使用 CSA 一次计算 96 个字节的位置填充

此变体执行上述所有优化,但预先使用单个 CSA 步骤将 96 字节减少到 64 字节。 正如预期的那样,这将性能提高了大约 30% 并达到了 16 GB/s。

#include "textflag.h"

// func PospopcntRegCSA(counts *[8]int32, buf []byte)
TEXT ·PospopcntRegCSA(SB),NOSPLIT,$0-32
    MOVQ counts+0(FP), DI
    MOVQ buf_base+8(FP), SI     // SI = &buf[0]
    MOVQ buf_len+16(FP), CX     // CX = len(buf)

    // load counts into register R8--R15
    MOVL 4*0(DI), R8
    MOVL 4*1(DI), R9
    MOVL 4*2(DI), R10
    MOVL 4*3(DI), R11
    MOVL 4*4(DI), R12
    MOVL 4*5(DI), R13
    MOVL 4*6(DI), R14
    MOVL 4*7(DI), R15

    SUBQ $96, CX            // pre-subtract 32 bit from CX
    JL scalar

vector: VMOVDQU (SI), Y0        // load 96 bytes from buf into Y0--Y2
    VMOVDQU 32(SI), Y1
    VMOVDQU 64(SI), Y2
    ADDQ $96, SI            // advance SI past them
    PREFETCHT0 320(SI)
    PREFETCHT0 384(SI)

    VPXOR Y0, Y1, Y3        // first adder: sum
    VPAND Y0, Y1, Y0        // first adder: carry out
    VPAND Y2, Y3, Y1        // second adder: carry out
    VPXOR Y2, Y3, Y2        // second adder: sum (full sum)
    VPOR Y0, Y1, Y0         // full adder: carry out

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R15

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R14

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R13

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R12

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R11

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R10

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    VPADDB Y0, Y0, Y0       // shift carry out bytes left
    VPADDB Y2, Y2, Y2       // shift sum bytes left
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R9

    VPMOVMSKB Y0, AX        // MSB of carry out bytes
    VPMOVMSKB Y2, DX        // MSB of sum bytes
    POPCNTL AX, AX          // carry bytes population count
    POPCNTL DX, DX          // sum bytes population count
    LEAL (DX)(AX*2), AX     // sum popcount plus 2x carry popcount
    ADDL AX, R8

    SUBQ $96, CX
    JGE vector          // repeat as long as bytes are left

scalar: ADDQ $96, CX            // undo last subtraction
    JE done             // if CX=0, there's nothing left

loop:   MOVBLZX (SI), AX        // load a byte from buf
    INCQ SI             // advance past it

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R8         // add it to R8

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R9         // add it to R9

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R10            // add it to R10

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R11            // add it to R11

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R12            // add it to R12

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R13            // add it to R13

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R14            // add it to R14

    SHRL $1, AX         // is bit 0 set?
    ADCL $0, R15            // add it to R15

    DECQ CX             // mark this byte as done
    JNE loop            // and proceed if any bytes are left

    // write R8--R15 back to counts
done:   MOVL R8, 4*0(DI)
    MOVL R9, 4*1(DI)
    MOVL R10, 4*2(DI)
    MOVL R11, 4*3(DI)
    MOVL R12, 4*4(DI)
    MOVL R13, 4*5(DI)
    MOVL R14, 4*6(DI)
    MOVL R15, 4*7(DI)

    VZEROUPPER          // restore SSE-compatibility
    RET

基准

以下是两种算法的基准测试和纯 Go 中的简单参考实现。可以在 github 存储库中找到完整的基准测试。

BenchmarkReference/10-12    12448764            80.9 ns/op   123.67 MB/s
BenchmarkReference/32-12     4357808           258 ns/op     124.25 MB/s
BenchmarkReference/1000-12            151173          7889 ns/op     126.76 MB/s
BenchmarkReference/2000-12             68959         15774 ns/op     126.79 MB/s
BenchmarkReference/4000-12             36481         31619 ns/op     126.51 MB/s
BenchmarkReference/10000-12            14804         78917 ns/op     126.72 MB/s
BenchmarkReference/100000-12            1540        789450 ns/op     126.67 MB/s
BenchmarkReference/10000000-12            14      77782267 ns/op     128.56 MB/s
BenchmarkReference/1000000000-12           1    7781360044 ns/op     128.51 MB/s
BenchmarkReg/10-12                  49255107            24.5 ns/op   407.42 MB/s
BenchmarkReg/32-12                  186935192            6.40 ns/op 4998.53 MB/s
BenchmarkReg/1000-12                 8778610           115 ns/op    8677.33 MB/s
BenchmarkReg/2000-12                 5358495           208 ns/op    9635.30 MB/s
BenchmarkReg/4000-12                 3385945           357 ns/op    11200.23 MB/s
BenchmarkReg/10000-12                1298670           901 ns/op    11099.24 MB/s
BenchmarkReg/100000-12                115629          8662 ns/op    11544.98 MB/s
BenchmarkReg/10000000-12                1270        916817 ns/op    10907.30 MB/s
BenchmarkReg/1000000000-12                12      93609392 ns/op    10682.69 MB/s
BenchmarkRegCSA/10-12               48337226            23.9 ns/op   417.92 MB/s
BenchmarkRegCSA/32-12               12843939            80.2 ns/op   398.86 MB/s
BenchmarkRegCSA/1000-12              7175629           150 ns/op    6655.70 MB/s
BenchmarkRegCSA/2000-12              3988408           295 ns/op    6776.20 MB/s
BenchmarkRegCSA/4000-12              3016693           382 ns/op    10467.41 MB/s
BenchmarkRegCSA/10000-12             1810195           642 ns/op    15575.65 MB/s
BenchmarkRegCSA/100000-12             191974          6229 ns/op    16053.40 MB/s
BenchmarkRegCSA/10000000-12             1622        698856 ns/op    14309.10 MB/s
BenchmarkRegCSA/1000000000-12             16      68540642 ns/op    14589.88 MB/s

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM