[英]How does this algorithm to count the number of set bits in a 32-bit integer work?
int SWAR(unsigned int i)
{
i = i - ((i >> 1) & 0x55555555);
i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
}
我看过这段代码,它计算 32 位整数中的位数等于1
,我注意到它的性能优于__builtin_popcount
但我无法理解它的工作方式。
有人可以详细解释这段代码是如何工作的吗?
好,让我们一行一行的看一遍代码:
i = i - ((i >> 1) & 0x55555555);
首先,常量0x55555555
的意义在于,使用 Java/GCC 风格的二进制文字表示法编写),
0x55555555 = 0b01010101010101010101010101010101
也就是说,它的所有奇数位(将最低位计算为位 1 = 奇数)都是1
,所有偶数位都是0
。
表达式((i >> 1) & 0x55555555)
因此将i
的位右移一,然后将所有偶数位设置为零。 (等效地,我们可以首先使用& 0xAAAAAAAA
将i
所有奇数位设置为零,然后将结果右移一位。)为方便起见,我们称此中间值j
。
当我们从原来的i
减去这个j
会发生什么? 好吧,让我们看看如果i
只有两个位会发生什么:
i j i - j
----------------------------------
0 = 0b00 0 = 0b00 0 = 0b00
1 = 0b01 0 = 0b00 1 = 0b01
2 = 0b10 1 = 0b01 1 = 0b01
3 = 0b11 1 = 0b01 2 = 0b10
嘿! 我们已经设法计算出两位数的位数!
好的,但是如果i
设置了两个以上的位呢? 事实上,很容易检查i - j
的最低两位是否仍由上表给出,第三和第四位、第五和第六位也是如此,依此类推。 特别是:
尽管>> 1
,最低的两位i - j
不受第三或更高位的i
,因为他们会来屏蔽掉j
由& 0x55555555
; 和
由于j
的最低两位永远不会有比i
更大的数值,减法永远不会从i
的第三位借用:因此, i
的最低两位也不会影响i
的第三位或更高位i - j
。
事实上,通过重复相同的参数,我们可以看到这一行的计算实际上将上表应用于i
in parallel中的 16 个两位块中的每一个。 也就是说,在执行这条线之后,新的价值的最低两位i
现在包含的原始值的对应位中设置的位数i
,等会下两个位,依此类推。
i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
与第一行相比,这一行相当简单。 首先,请注意
0x33333333 = 0b00110011001100110011001100110011
因此, i & 0x33333333
获取上面计算的两位计数并每秒丢弃其中一个,而(i >> 2) & 0x33333333
在将i
右移两位后做同样的事情。 然后我们将结果相加。
因此,实际上,这一行的作用是获取在前一行计算的原始输入的最低两位和第二低两位的位数,并将它们加在一起以给出输入的最低四位的位数。输入。 同样,它对输入的所有8 个四位块(= 十六进制数字)并行执行此操作。
return (((i + (i >> 4)) & 0x0F0F0F0F) * 0x01010101) >> 24;
好的,这是怎么回事?
嗯,首先, (i + (i >> 4)) & 0x0F0F0F0F
与上一行完全相同,除了它将相邻的四位比特数加在一起以给出每个八位块(即字节) 的输入。 (这里,与前一行不同,我们可以将&
移到加法之外,因为我们知道八位的位数永远不会超过 8,因此可以放入四位而不会溢出。)
现在我们有一个 32 位数字,由四个 8 位字节组成,每个字节保存原始输入的那个字节中的 1 位数字。 (我们称这些字节为A
、 B
、 C
和D
。)那么当我们将此值(我们称其为k
)乘以0x01010101
什么?
好吧,由于0x01010101 = (1 << 24) + (1 << 16) + (1 << 8) + 1
,我们有:
k * 0x01010101 = (k << 24) + (k << 16) + (k << 8) + k
因此,结果的最高字节最终是以下各项的总和:
k
项,加上k << 8
项,加上k << 16
项,加上k << 24
项,第四个和最低字节的值。(通常,低字节也可能有进位,但由于我们知道每个字节的值最多为 8,我们知道加法永远不会溢出并创建进位。)
也就是说, k * 0x01010101
的最高字节最终是输入的所有字节的比特数的总和,即32 位输入数的总比特数。 最后的>> 24
然后简单地将该值从最高字节向下移动到最低字节。
附言。 这段代码可以很容易地扩展到 64 位整数,只需将0x01010101
更改为0x0101010101010101
并将>> 24
更改为>> 56
。 事实上,同样的方法甚至适用于 128 位整数; 然而,256 位需要添加一个额外的移位/添加/掩码步骤,因为数字 256 不再完全适合 8 位字节。
我更喜欢这个,它更容易理解。
x = (x & 0x55555555) + ((x >> 1) & 0x55555555);
x = (x & 0x33333333) + ((x >> 2) & 0x33333333);
x = (x & 0x0f0f0f0f) + ((x >> 4) & 0x0f0f0f0f);
x = (x & 0x00ff00ff) + ((x >> 8) & 0x00ff00ff);
x = (x & 0x0000ffff) + ((x >> 16) &0x0000ffff);
这是对Ilamari 的回答的评论。 由于格式问题,我将其作为答案:
i = i - ((i >> 1) & 0x55555555); // (1)
此行源自此更易于理解的行:
i = (i & 0x55555555) + ((i >> 1) & 0x55555555); // (2)
如果我们打电话
i = input value
j0 = i & 0x55555555
j1 = (i >> 1) & 0x55555555
k = output value
我们可以重写(1)和(2),使解释更清楚:
k = i - j1; // (3)
k = j0 + j1; // (4)
我们想证明(3)可以从(4)导出。
i
可以写成其偶数位和奇数位的相加(将最低位计算为位 1 = 奇数):
i = iodd + ieven =
= (i & 0x55555555) + (i & 0xAAAAAAAA) =
= (i & modd) + (i & meven)
由于meven
掩码清除了i
的最后一位,所以最后一个等式可以这样写:
i = (i & modd) + ((i >> 1) & modd) << 1 =
= j0 + 2*j1
那是:
j0 = i - 2*j1 (5)
最后,将(5)代入(4)我们得到(3):
k = j0 + j1 = i - 2*j1 + j1 = i - j1
这是对yeer的回答的解释:
int SWAR(unsigned int i) {
i = (i & 0x55555555) + ((i >> 1) & 0x55555555); // A
i = (i & 0x33333333) + ((i >> 2) & 0x33333333); // B
i = (i & 0x0f0f0f0f) + ((i >> 4) & 0x0f0f0f0f); // C
i = (i & 0x00ff00ff) + ((i >> 8) & 0x00ff00ff); // D
i = (i & 0x0000ffff) + ((i >> 16) &0x0000ffff); // E
return i;
}
让我们以 A 行作为我解释的基础。
i = (i & 0x55555555) + ((i >> 1) & 0x55555555)
让我们将上面的表达式重命名如下:
i = (i & mask) + ((i >> 1) & mask)
= A1 + A2
首先,不要将i
视为 32 位,而是将其视为 16 组的数组,每组 2 位。 A1
是大小为 16 的计数数组,每组包含i
中相应组最右边的1
的计数:
i = yx yx yx yx yx yx yx yx yx yx yx yx yx yx yx yx
mask = 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01 01
i & mask = 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x 0x
同样, A2
正在“计算” i
中每个组的最左边位。 请注意,我可以将A2 = (i >> 1) & mask
重写为A2 = (i & mask2) >> 1
:
i = yx yx yx yx yx yx yx yx yx yx yx yx yx yx yx yx
mask2 = 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10 10
(i & mask2) = y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0 y0
(i & mask2) >> 1 = 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y 0y
(注意mask2 = 0xaaaaaaaa
)
因此, A1 + A2
将A1
数组和A2
数组的计数相加,得到一个 16 组的数组,每组现在包含每组中的位数。
移动到 B 行,我们可以将该行重命名如下:
i = (i & 0x33333333) + ((i >> 2) & 0x33333333)
= (i & mask) + ((i >> 2) & mask)
= B1 + B2
B1 + B2
遵循与之前的A1 + A2
相同的“形式”。 不再将i
视为 16 组 2 位,而是 8 组 4 位。 所以和之前类似, B1 + B2
将B1
和B2
的计数加在一起,其中B1
是该组右侧1
的计数, B2
是该组左侧的计数。 因此, B1 + B2
是每组中的位数。
C 到 E 行现在变得更容易理解了:
int SWAR(unsigned int i) {
// A: 16 groups of 2 bits, each group contains number of 1s in that group.
i = (i & 0x55555555) + ((i >> 1) & 0x55555555);
// B: 8 groups of 4 bits, each group contains number of 1s in that group.
i = (i & 0x33333333) + ((i >> 2) & 0x33333333);
// C: 4 groups of 8 bits, each group contains number of 1s in that group.
i = (i & 0x0f0f0f0f) + ((i >> 4) & 0x0f0f0f0f);
// D: 2 groups of 16 bits, each group contains number of 1s in that group.
i = (i & 0x00ff00ff) + ((i >> 8) & 0x00ff00ff);
// E: 1 group of 32 bits, containing the number of 1s in that group.
i = (i & 0x0000ffff) + ((i >> 16) &0x0000ffff);
return i;
}
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.