
Bit manipulation: keeping the common part at the left of the last different bit

Consider two numbers written in binary (MSB at left):

X = x7 x6 x5 x4 x3 x2 x1 x0

and

Y = y7 y6 y5 y4 y3 y2 y1 y0

These numbers can have an arbitrary number of bits but both are of the same type. Now consider that x7 == y7 , x6 == y6 , x5 == y5 , but x4 != y4 .

How to compute:

Z = x7 x6 x5 0 0 0 0 0

or in other words, how to efficiently compute a number that keeps only the common part to the left of the highest differing bit?

template <typename T>
inline T f(const T x, const T y) 
{
    // Something here
}

For example, for:

x = 10100101
y = 10110010

it should return

z = 10100000

Note: this is for supercomputing purposes and the operation will be executed hundreds of billions of times, so scanning the bits one by one should be avoided.

My answer is based on @JerryCoffin's one.

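unsigned d = x ^ y;   // unsigned, so that >> shifts in zeros
d = d | (d >> 1);
d = d | (d >> 2);
d = d | (d >> 4);
d = d | (d >> 8);
d = d | (d >> 16);    // enough for 32 bits; add (d >> 32) for 64-bit values
unsigned z = x & (~d);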
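
The same smear can be written once for any unsigned width, matching the templated signature from the question. This is only a sketch: the name common_prefix is made up for illustration, and it assumes T is an unsigned integer type so that right shifts fill with zeros.

```cpp
#include <limits>
#include <type_traits>

// Hypothetical name: keeps the bits of x above the highest bit where x and y differ.
template <typename T>
inline T common_prefix(const T x, const T y)
{
    static_assert(std::is_unsigned<T>::value,
                  "use an unsigned type so that >> shifts in zeros");
    T d = x ^ y;
    // Smear the highest set bit of d into every lower position,
    // doubling the shift each step (1, 2, 4, ... up to half the width).
    for (int s = 1; s < std::numeric_limits<T>::digits; s <<= 1)
        d |= static_cast<T>(d >> s);
    return static_cast<T>(x & static_cast<T>(~d));
}
```

With a constant trip count the loop is a good candidate for full unrolling, so it should not cost more than the hand-written sequence, though that is worth checking in the generated code.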

Part of this problem shows up semi-regularly in bit-manipulation: "parallel suffix with OR", or "prefix" (that is, depending on who you listen to, the low bits are either called a suffix or a prefix). Obviously once you have a way to do that, it's trivial to extend it to what you want (as shown in the other answers).

Anyway, the obvious way is:

x |= x >> 1;
x |= x >> 2;
x |= x >> 4;
x |= x >> 8;
x |= x >> 16;

But you're probably not constrained to simple operators.

For Haswell, the fastest way I found was:

lzcnt rax, rax     ; number of leading zeroes, sets carry if rax=0
mov edx, 64
sub edx, eax
mov rax, -1
bzhi rax, rax, rdx ; reset the bits in rax starting at position rdx
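
If you would rather stay in C++, the idea behind the sequence above (count the leading zeroes, then build a mask of that width) can be sketched with the GCC/Clang builtin __builtin_clzll. The helper names smear_right and common_prefix64 are made up for this example. Whether the compiler actually emits lzcnt for it depends on the target flags (for instance -march=haswell), which is an assumption about your build.

```cpp
#include <cstdint>

// Hypothetical helper: all ones from the highest set bit of v downward, 0 if v == 0.
inline std::uint64_t smear_right(std::uint64_t v)
{
    if (v == 0) return 0;                          // __builtin_clzll(0) is undefined
    return ~std::uint64_t(0) >> __builtin_clzll(v); // shift count is 0..63 here
}

// Hypothetical helper: common leading part of x and y, lower bits cleared.
inline std::uint64_t common_prefix64(std::uint64_t x, std::uint64_t y)
{
    return x & ~smear_right(x ^ y);
}
```

The explicit zero check plays the role of the carry/zero-flag handling in the assembly versions.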

Other contenders were:

mov rdx, -1
bsr rax, rax       ; position of the highest set bit, set Z flag if no bit
cmovz rdx, rax     ; set rdx=rax iff Z flag is set
xor eax, 63
shrx rax, rdx, rax ; rax = rdx >> rax

And

lzcnt rax, rax
sbb rdx, rdx       ; rdx -= rdx + carry (so 0 if no carry, -1 if carry)
not rdx
shrx rax, rdx, rax

But they were not as fast.

I've also considered

lzcnt rax, rax
mov rax, [table+rax*8]

But it's hard to compare it fairly, since it's the only one that spends cache space, which has non-local effects.

Benchmarking various ways to do this led to this question about some curious behaviour of lzcnt .

They all rely on some fast way to determine the position of the highest set bit, which you could do with a cast to float and exponent extraction if you really had to, so probably most platforms can use something like it.
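
As a rough illustration of the float trick, here is a sketch for 32-bit inputs that converts to double and reads the exponent field. It assumes IEEE 754 binary64 doubles, and the function names are invented for the example. A 32-bit value always fits in the 53-bit significand, so the conversion is exact and the exponent is the index of the highest set bit.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical helper: index of the highest set bit of a nonzero 32-bit value,
// taken from the exponent field of its double representation.
// Assumes IEEE 754 binary64; the result is meaningless for v == 0.
inline int highest_bit_via_double(std::uint32_t v)
{
    double d = static_cast<double>(v);      // exact for any 32-bit value
    std::uint64_t bits;
    std::memcpy(&bits, &d, sizeof bits);    // well-defined way to read the bit pattern
    return static_cast<int>((bits >> 52) & 0x7FF) - 1023;
}

// Hypothetical helper built on top of it.
inline std::uint32_t common_prefix_fp(std::uint32_t x, std::uint32_t y)
{
    std::uint32_t d = x ^ y;
    if (d == 0) return x;                   // nothing differs
    int b = highest_bit_via_double(d);      // 0..31
    // Ones in bits b..0; for b == 31 the shift wraps to 0 and the
    // subtraction yields all ones, which is the mask we want.
    std::uint32_t mask = (std::uint32_t(2) << b) - 1;
    return x & ~mask;
}
```

For 64-bit inputs the conversion to double can round up to the next power of two, so this exact form should not be reused there without extra care.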

A shift that gives zero if the shift-count is equal to or bigger than the operand size would be very nice to solve this problem. x86 doesn't have one, but maybe your platform does.

If you had a fast bit-reversal instruction, you could do something like: (this isn't intended to be ARM asm)

rbit r0, r0
neg r1, r0
or r0, r1, r0
rbit r0, r0

Comparing several algorithms leads to this ranking:

With an inner loop of 1 or 10 iterations in the test below:

  1. Utilizing a built in bit scan function.
  2. Filling least significant bits with or and shift (The function of @Egor Skriptunoff).
  3. Involving a lookup table.
  4. Scanning the most significant bit (The second function of @Tomas).

InnerLoops = 10 (Timing n belongs to function n in the test below, not to rank n):

Timing 1: 0.101284
Timing 2: 0.108845
Timing 3: 0.102526
Timing 4: 0.191911

With an inner loop of 100 or more iterations:

  1. Utilizing a built in bit scan function.
  2. Involving a lookup table.
  3. Filling least significant bits with or and shift (The function of @Egor Skriptunoff).
  4. Scanning the most significant bit (The second function of @Tomas).

InnerLoops = 100 (Timing n belongs to function n in the test below, not to rank n):

Timing 1: 0.441786
Timing 2: 0.507651
Timing 3: 0.548328
Timing 4: 0.593668

The test:

#include <algorithm>
#include <chrono>
#include <cstdint>   // std::uint8_t
#include <cstdlib>   // std::rand, exit
#include <limits>
#include <iostream>
#include <iomanip>

// Functions
// =========

inline unsigned function1(unsigned  a, unsigned b)
{
    a ^= b;
    if(a) {
        int n = __builtin_clz (a);
        a = (~0u) >> n;
    }
    return ~a & b;
}

typedef std::uint8_t byte;
static byte msb_table[256] = {
    0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
    6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
    8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8,
};

inline unsigned function2(unsigned a, unsigned  b)
{
    a ^= b;
    if(a) {
        unsigned n = 0;
        if(a >> 24) n = msb_table[byte(a >> 24)] + 24;
        else if(a >> 16) n = msb_table[byte(a >> 16)] + 16;
        else if(a >> 8) n = msb_table[byte(a >> 8)] + 8;
        else n = msb_table[byte(a)];
        a = (~0u) >> (32-n);
    }
    return ~a & b;
}

inline unsigned function3(unsigned  a, unsigned  b)
{
    unsigned d = a ^ b;
    d = d | (d >> 1);
    d = d | (d >> 2);
    d = d | (d >> 4);
    d = d | (d >> 8);
    d = d | (d >> 16);
    return a & (~d);
}

inline unsigned function4(unsigned  a, unsigned  b)
{
    const unsigned maxbit = 1u << (std::numeric_limits<unsigned>::digits - 1);
    unsigned msb = maxbit;
    a ^= b;
    // Identical inputs: nothing differs, and the scan below would never terminate.
    if(a == 0) return b;
    while( ! (a & msb))
        msb >>= 1;
    if(msb == maxbit) return 0;
    else {
        msb <<= 1;
        msb  -= 1;
        return ~msb & b;
    }
}


// Test
// ====

inline double duration(
    std::chrono::system_clock::time_point start,
    std::chrono::system_clock::time_point end)
{
    return double((end - start).count())
        / std::chrono::system_clock::period::den;
}

int main() {
    typedef unsigned (*Function)(unsigned , unsigned);
    Function fn[] = {
        function1,
        function2,
        function3,
        function4,
    };
    const unsigned N = sizeof(fn) / sizeof(fn[0]);
    std::chrono::system_clock::duration timing[N] = {};
    const unsigned OuterLoops = 1000000;
    const unsigned InnerLoops = 100;
    const unsigned Samples = OuterLoops * InnerLoops;
    unsigned* A = new unsigned[Samples];
    unsigned* B = new unsigned[Samples];
    for(unsigned i = 0; i < Samples; ++i) {
        A[i] = std::rand();
        B[i] = std::rand();
    }
    unsigned F[N];
    for(unsigned f = 0; f < N; ++f) F[f] = f;
    unsigned result[N];
    for(unsigned i = 0; i < OuterLoops; ++i) {
        std::random_shuffle(F, F + N);
        for(unsigned f = 0; f < N; ++f) {
            unsigned g = F[f];
            auto start = std::chrono::system_clock::now();
            for(unsigned j = 0; j < InnerLoops; ++j) {
                unsigned index = i + j;
                unsigned a = A[index];
                unsigned b = B[index];
                result[g] = fn[g](a, b);
            }
            auto end = std::chrono::system_clock::now();
            timing[g] += (end-start);
        }
        for(unsigned f = 1; f < N; ++f) {
            if(result[0] != result[f]) {
                std::cerr << "Different Results\n" << std::hex;
                for(unsigned g = 0; g < N; ++g)
                    std::cout << "Result " << g+1 << ": " << result[g] << '\n';
                exit(-1);
            }
        }
    }

    for(unsigned i = 0; i < N; ++i) {
        std::cout
            << "Timing " << i+1 << ": "
            << double(timing[i].count()) / std::chrono::system_clock::period::den
            << "\n";
    }
}

Compiler:

g++ 4.7.2

Hardware:

Intel® Core™ i3-2310M CPU @ 2.10GHz × 4 7.7 GiB

You can reduce it to the much easier problem of finding the highest set bit (highest 1) of X, whose index is floor(log2 X).

unsigned int x, y, c, m;   // x and y are the two inputs
int b;

c = x ^ y;          // xor : 00010111

// b = index of the highest set bit in c
// (a dedicated instruction or builtin may exist for this)
b = -1;
while (c) {
    b++;
    c = c >> 1;
}                  // b == 4

m = (1u << (b + 1)) - 1;  // creates a mask: 00011111 (assumes b + 1 < bit width)
return x & ~m;    // x AND NOT m; y & ~m gives the same result

In fact, if you can compute b = floor(log2 c) directly, then m = 2^(b+1) - 1 follows immediately, without the loop above.

If you don't have such functionality, simple optimized code, which uses just basic assembly level operations (bit shifts by one bit: <<=1 , >>=1 ) would look like this:

c = x ^ y;        // c == 00010111 (xor)
m = 1;
while (c) {
    m <<= 1; 
    c >>= 1;
}                 // m == 00100000
m--;              // m == 00011111 (mask)
return x & ~m;    // x AND NOT M

This compiles to very fast code, typically one or two machine instructions per line.

It's a little ugly, but assuming 8-bit inputs, you can do something like this:

int x = 0xA5; // 1010 0101
int y = 0xB2; // 1011 0010
unsigned d = x ^ y;

int mask = ~(d | (d >> 1) | (d >> 2) | (d >> 3) | (d >> 4) | (d >> 5) | (d >> 6) | (d >> 7));

int z = x & mask;

We start by computing the exclusive-or of the numbers, which will give a 0 where they're equal, and a 1 where they're different. For your example, that gives:

00010111

We then shift that right by each of the 7 possible amounts and inclusive-or all the results together (only the non-zero shifts are shown):

00010111
00001011
00000101
00000010
00000001

That gives:

00011111

That is 0s across the common leading part and 1s from the first differing bit downward. We then invert that to get:

11100000

Then we and that with one of the original inputs (doesn't matter which) to get:

10100000

...exactly the result we wanted (and unlike a simple x & y , it'll also work for other values of x and y ).

Of course, this can be extended out to an arbitrary width, but if you were working with (say) 64-bit numbers, the d | (d>>1) | ... | (d>>63); would be a little on the long and clumsy side.
