繁体   English   中英

在没有共享的情况下计算每个扭曲直方图 memory

[英]Compute per-warp histogram without shared memory

问题计算 warp 中各个线程持有的已排序数字序列的 per-warp 直方图。

例子:

lane: 0123456789...          31
val:  222244455777799999 ..

结果必须由 N 个较低的线程保存在 warp 中(其中 N 是唯一数字的数量),例如:

lane 0: val=2, num=4 (2 occurs 4 times)
lane 1: val=4, num=3 (4 occurs 3 times)
lane 2: val=5, num=2 ...
lane 3: val=7, num=4
lane 4: val=9, num=5
...

请注意,基本上不需要对“val”序列进行排序:只需要将相等的数字组合在一起,即:99955555773333333 ...

可能的解决方案这可以通过 shuffle 内在函数非常有效地完成,尽管我的问题是是否可以在使用共享 memory 的情况下做到这一点(我的意思是共享 memory 是一种稀缺资源,我在其他地方需要它)?

为简单起见,我仅对单个 warp 执行此代码(以便 printf 正常工作):

__device__ __inline__ void sorted_seq_histogram()
{
    uint32_t tid = threadIdx.x, lane = tid % 32;
    uint32_t val = (lane + 117)* 23 / 97; // sorted sequence of values to be reduced

    printf("%d: val = %d\n", lane, val);
    uint32_t num = 1;

    uint32_t allmsk = 0xffffffffu, shfl_c = 31;
    for(int i = 1; i <= 16; i *= 2) {

#if 1
        uint32_t xval = __shfl_down_sync(allmsk, val, i),
                 xnum = __shfl_down_sync(allmsk, num, i);
        if(lane + i < 32) {
            if(val == xval)
                num += xnum;
        }
#else  // this is a (hopefully) optimized version of the code above
        asm(R"({
          .reg .u32 r0,r1;
          .reg .pred p;
          shfl.sync.down.b32 r0|p, %1, %2, %3, %4;
          shfl.sync.down.b32 r1|p, %0, %2, %3, %4;
          @p setp.eq.s32 p, %1, r0;
          @p add.u32 r1, r1, %0;
          @p mov.u32 %0, r1;
        })"
        : "+r"(num) : "r"(val), "r"(i), "r"(shfl_c), "r"(allmsk));
#endif
    }
    // shfl.sync wraps around: so thread 0 gets the value of thread 31
    bool leader = val != __shfl_sync(allmsk, val, lane - 1);
    auto OK = __ballot_sync(allmsk, leader); // find delimiter threads
    auto total = __popc(OK); // the total number of unique numbers found

    auto lanelt = (1 << lane) - 1;
    auto idx = __popc(OK & lanelt);

    printf("%d: val = %d; num = %d; total: %d; idx = %d; leader: %d\n", lane, val, num, total, idx, leader);

    __shared__ uint32_t sh[64];
    if(leader) {   // here we need shared memory :(
        sh[idx] = val;
        sh[idx + 32] = num;
    }
    __syncthreads();

    if(lane < total) {
        val = sh[lane], num = sh[lane + 32];
    } else {
        val = 0xDEADBABE, num = 0;
    }
    printf("%d: final val = %d; num = %d\n", lane, val, num);
}

这是我的 GPU output:

0: val = 27
1: val = 27
2: val = 28
3: val = 28
4: val = 28
5: val = 28
6: val = 29
7: val = 29
8: val = 29
9: val = 29
10: val = 30
11: val = 30
12: val = 30
13: val = 30
14: val = 31
15: val = 31
16: val = 31
17: val = 31
18: val = 32
19: val = 32
20: val = 32
21: val = 32
22: val = 32
23: val = 33
24: val = 33
25: val = 33
26: val = 33
27: val = 34
28: val = 34
29: val = 34
30: val = 34
31: val = 35
0: val = 27; num = 2; total: 9; idx = 0; leader: 1
1: val = 27; num = 1; total: 9; idx = 1; leader: 0
2: val = 28; num = 4; total: 9; idx = 1; leader: 1
3: val = 28; num = 3; total: 9; idx = 2; leader: 0
4: val = 28; num = 2; total: 9; idx = 2; leader: 0
5: val = 28; num = 1; total: 9; idx = 2; leader: 0
6: val = 29; num = 4; total: 9; idx = 2; leader: 1
7: val = 29; num = 3; total: 9; idx = 3; leader: 0
8: val = 29; num = 2; total: 9; idx = 3; leader: 0
9: val = 29; num = 1; total: 9; idx = 3; leader: 0
10: val = 30; num = 4; total: 9; idx = 3; leader: 1
11: val = 30; num = 3; total: 9; idx = 4; leader: 0
12: val = 30; num = 2; total: 9; idx = 4; leader: 0
13: val = 30; num = 1; total: 9; idx = 4; leader: 0
14: val = 31; num = 4; total: 9; idx = 4; leader: 1
15: val = 31; num = 3; total: 9; idx = 5; leader: 0
16: val = 31; num = 2; total: 9; idx = 5; leader: 0
17: val = 31; num = 1; total: 9; idx = 5; leader: 0
18: val = 32; num = 5; total: 9; idx = 5; leader: 1
19: val = 32; num = 4; total: 9; idx = 6; leader: 0
20: val = 32; num = 3; total: 9; idx = 6; leader: 0
21: val = 32; num = 2; total: 9; idx = 6; leader: 0
22: val = 32; num = 1; total: 9; idx = 6; leader: 0
23: val = 33; num = 4; total: 9; idx = 6; leader: 1
24: val = 33; num = 3; total: 9; idx = 7; leader: 0
25: val = 33; num = 2; total: 9; idx = 7; leader: 0
26: val = 33; num = 1; total: 9; idx = 7; leader: 0
27: val = 34; num = 4; total: 9; idx = 7; leader: 1
28: val = 34; num = 3; total: 9; idx = 8; leader: 0
29: val = 34; num = 2; total: 9; idx = 8; leader: 0
30: val = 34; num = 1; total: 9; idx = 8; leader: 0
31: val = 35; num = 1; total: 9; idx = 8; leader: 1
0: final val = 27; num = 2
1: final val = 28; num = 4
2: final val = 29; num = 4
3: final val = 30; num = 4
4: final val = 31; num = 4
5: final val = 32; num = 5
6: final val = 33; num = 4
7: final val = 34; num = 4
8: final val = 35; num = 1
9: final val = -559039810; num = 0
10: final val = -559039810; num = 0
11: final val = -559039810; num = 0
12: final val = -559039810; num = 0
13: final val = -559039810; num = 0
14: final val = -559039810; num = 0
15: final val = -559039810; num = 0
16: final val = -559039810; num = 0
17: final val = -559039810; num = 0
18: final val = -559039810; num = 0
19: final val = -559039810; num = 0
20: final val = -559039810; num = 0
21: final val = -559039810; num = 0
22: final val = -559039810; num = 0
23: final val = -559039810; num = 0
24: final val = -559039810; num = 0
25: final val = -559039810; num = 0
26: final val = -559039810; num = 0
27: final val = -559039810; num = 0
28: final val = -559039810; num = 0
29: final val = -559039810; num = 0
30: final val = -559039810; num = 0
31: final val = -559039810; num = 0

问题是否可以在不使用共享 memory 的情况下执行此操作? 不知何故,我无法用所有这些令人费解的洗牌内在函数来解决这个问题。

可以找出每个线程需要洗牌的通道,然后只使用__shfl_sync 唯一的问题/烦恼是我不知道没有循环就可以做到这一点。

所需的操作是找到OK中第n个设置位的“索引”,其中n是线程的通道。 SO 问题给定一个二进制数,如何在 O(1) 时间内找到从右数第 n 个设置位? 是关于这个问题,但它的答案只显示迭代解决方案。 然而,由于该问题与任何编程语言或内在函数无关,因此可能有某种方法可以巧妙地使用 integer 内在函数。

无论哪种方式,以下对我有用:

    // ... second printf
    auto src = lane;
    auto cnt = -1;
    for (int i = 0; i < warpSize; ++i) {
        if ((OK >> i) & 0x1 == 0x1) {
            ++cnt;
            if (cnt == lane) {
                src = i;
                break;
            }
        }
    }
    val = __shfl_sync(allmsk, val, src);
    num = __shfl_sync(allmsk, num, src);
    if (lane >= total) {
        val = 0xDEADBABE;
        num = 0;
    }
    // third printf ...

我不知道它在性能方面如何比较(应该在没有打印语句的情况下进行测量)。

我想我找到了解决方案:正如 paleonix 也指出的那样,问题是我们需要计算第 N 个位集。

实际上有一个名为fns.b32的 PTX 内在函数非常有趣,它就是这样做的。 然而,在我的 SM30 架构上,当我运行反汇编程序时,它会映射到一些疯狂的东西。

无论如何,我们还有popcount上的快速弹出计数内在函数,可用于计算对数时间内第 N 位集合的 position。 下面是完整的代码,现在根本不需要共享 memory:

__device__ __inline__ void sorted_seq_histogram()
{
    uint32_t tid = threadIdx.x, lane = tid % 32;
    uint32_t val = (lane + 117)* 23 / 97; // sorted sequence of values to be reduced

    printf("%d: val = %d\n", lane, val);
    uint32_t num = 1;

    const uint32_t allmsk = 0xffffffffu, shfl_c = 31;

    // shfl.sync wraps around: so thread 0 gets the value of thread 31
    bool leader = val != __shfl_sync(allmsk, val, lane - 1);
    auto OK = __ballot_sync(allmsk, leader); // find delimiter threads
    auto total = __popc(OK); // the total number of unique numbers found
    uint32_t pos = 0, N = lane+1; // each thread searches Nth bit set in 'OK' (1-indexed)

    for(int i = 1, j = 16; i <= 16; i *= 2, j /= 2) {

        uint32_t mval = OK & ((1 << j)-1); // here we compute the Nth bit set
        auto dif = (int)(N - __popc(mval));
        if(dif <= 0) {
            OK = mval;
        } else {
            N = dif, pos += j, OK >>= j;
        }
#if 1
        uint32_t xval = __shfl_down_sync(allmsk, val, i),
                 xnum = __shfl_down_sync(allmsk, num, i);
        if(lane + i < 32) {
            if(val == xval)
                num += xnum;
        }
#else  // this is a (hopefully) optimized version of the code above
        asm(R"({
          .reg .u32 r0,r1;
          .reg .pred p;
          shfl.sync.down.b32 r0|p, %1, %2, %3, %4;
          shfl.sync.down.b32 r1|p, %0, %2, %3, %4;
          @p setp.eq.s32 p, %1, r0;
          @p add.u32 r1, r1, %0;
          @p mov.u32 %0, r1;
        })"
        : "+r"(num) : "r"(val), "r"(i), "r"(shfl_c), "r"(allmsk));
#endif
    }
    num = __shfl_sync(allmsk, num, pos); // read from pos-th thread
    val = __shfl_sync(allmsk, val, pos); // read from pos-th thread
    if(lane >= total) {
        num = 0xDEADBABE;
    }

    printf("%d: final val = %d; num = %d; \n", lane, val, num);
}

和程序 output:

0: val = 27
1: val = 27
2: val = 28
3: val = 28
4: val = 28
5: val = 28
6: val = 29
7: val = 29
8: val = 29
9: val = 29
10: val = 30
11: val = 30
12: val = 30
13: val = 30
14: val = 31
15: val = 31
16: val = 31
17: val = 31
18: val = 32
19: val = 32
20: val = 32
21: val = 32
22: val = 32
23: val = 33
24: val = 33
25: val = 33
26: val = 33
27: val = 34
28: val = 34
29: val = 34
30: val = 34
31: val = 35
0: final val = 27; num = 2;
1: final val = 28; num = 4;
2: final val = 29; num = 4;
3: final val = 30; num = 4;
4: final val = 31; num = 4;
5: final val = 32; num = 5;
6: final val = 33; num = 4;
7: final val = 34; num = 4;
8: final val = 35; num = 1;
9: final val = 35; num = -559039810;
10: final val = 35; num = -559039810;
11: final val = 35; num = -559039810;
12: final val = 35; num = -559039810;
13: final val = 35; num = -559039810;
14: final val = 35; num = -559039810;
15: final val = 35; num = -559039810;
16: final val = 35; num = -559039810;
17: final val = 35; num = -559039810;
18: final val = 35; num = -559039810;
19: final val = 35; num = -559039810;
20: final val = 35; num = -559039810;
21: final val = 35; num = -559039810;
22: final val = 35; num = -559039810;
23: final val = 35; num = -559039810;
24: final val = 35; num = -559039810;
25: final val = 35; num = -559039810;
26: final val = 35; num = -559039810;
27: final val = 35; num = -559039810;
28: final val = 35; num = -559039810;
29: final val = 35; num = -559039810;
30: final val = 35; num = -559039810;
31: final val = 35; num = -559039810;

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM