How can I create a __m128i with the n most significant bits set (across the entire 128-bit vector)? I need this to mask the portions of a buffer that are relevant for a computation. If possible, the solution should have no branches, but this seems hard to achieve.
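For reference, a straightforward scalar version of what I'm after could look like this (branchy; hi/lo hold the upper/lower 64 bits of the vector, and the names are mine):

```c
#include <stdint.h>

/* Scalar reference (branchy): set the n most significant bits of a
   128-bit value represented as two 64-bit halves. */
static void ref_mask(int n, uint64_t *hi, uint64_t *lo)
{
    if (n >= 128)      { *hi = ~0ull; *lo = ~0ull; }
    else if (n > 64)   { *hi = ~0ull; *lo = ~0ull << (128 - n); }
    else if (n == 64)  { *hi = ~0ull; *lo = 0; }
    else if (n > 0)    { *hi = ~0ull << (64 - n); *lo = 0; }
    else               { *hi = 0;     *lo = 0; }
}
```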
I'm adding this as a second answer and leaving the first answer for historical interest. It looks like you can do something more efficient with _mm_slli_epi64:
#include <emmintrin.h>
#include <stdio.h>
#include <stdlib.h>

__m128i bit_mask(int n)
{
    __m128i v0 = _mm_set_epi64x(-1, -(n > 64));   /* AND mask */
    __m128i v1 = _mm_set_epi64x(-(n > 64), 0);    /* OR mask */
    __m128i v2 = _mm_slli_epi64(_mm_set1_epi64x(-1), (128 - n) & 63);
    v2 = _mm_and_si128(v2, v0);
    v2 = _mm_or_si128(v2, v1);
    return v2;
}

int main(int argc, char *argv[])
{
    int n = 36;
    if (argc > 1) n = atoi(argv[1]);
    printf("bit_mask(%3d) = %02vx\n", n, bit_mask(n)); /* %vx: Apple GCC vector printf extension */
    return 0;
}
Test:
$ gcc -Wall -msse2 sse_bit_mask.c
$ for n in 1 2 3 63 64 65 127 128 ; do ./a.out $n ; done
bit_mask( 1) = 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 80
bit_mask( 2) = 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 c0
bit_mask( 3) = 00 00 00 00 00 00 00 00 00 00 00 00 00 00 00 e0
bit_mask( 63) = 00 00 00 00 00 00 00 00 fe ff ff ff ff ff ff ff
bit_mask( 64) = 00 00 00 00 00 00 00 00 ff ff ff ff ff ff ff ff
bit_mask( 65) = 00 00 00 00 00 00 00 80 ff ff ff ff ff ff ff ff
bit_mask(127) = fe ff ff ff ff ff ff ff ff ff ff ff ff ff ff ff
bit_mask(128) = ff ff ff ff ff ff ff ff ff ff ff ff ff ff ff ff
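Note: the %02vx conversion used above is a vector printf extension (Apple's GCC supports it); mainline GCC/glibc will not accept it. As a portable alternative, a small helper can format the 16 bytes from least to most significant, matching the output above (the helper is my own, not part of the original code):

```c
#include <emmintrin.h>
#include <stdint.h>
#include <stdio.h>

/* Format an __m128i as 16 hex bytes, least significant byte first,
   space-separated; out needs room for at least 48 chars. */
static void format_m128i(__m128i v, char *out)
{
    uint8_t b[16];
    _mm_storeu_si128((__m128i *)b, v);
    for (int i = 0; i < 16; i++)
        sprintf(out + 3 * i, "%02x%c", b[i], i < 15 ? ' ' : '\0');
}
```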
You can use one of the methods from this question to generate a mask with the n most significant bytes set to all ones. You would then just need to fix up any remaining bits when n is not a multiple of 8.
I suggest trying something like this:
- init vector A = all (8 bit) elements to the residual mask of n % 8 bits
- init vector B = mask of n / 8 bytes using one of the above-mentioned methods
- init vector C = mask of (n + 7) / 8 bytes using one of the above-mentioned methods
- result = (A & C) | B
So for example if n = 36:
A = f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0 f0
B = ff ff ff ff 00 00 00 00 00 00 00 00 00 00 00 00
C = ff ff ff ff ff 00 00 00 00 00 00 00 00 00 00 00
==> ff ff ff ff f0 00 00 00 00 00 00 00 00 00 00 00
This would be branchless, as required, but it's probably on the order of ~10 instructions. There may be a more efficient method, but I would need to give this some more thought.
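A concrete sketch of the recipe above, where an unaligned load from a 32-byte 0x00/0xff table stands in for "one of the above-mentioned methods" of building the byte masks (the table trick and all names are my own assumptions, not from the linked question; the combining step that matches the worked example is (A & C) | B):

```c
#include <emmintrin.h>
#include <stdint.h>

/* Byte mask with the k most significant bytes set (0 <= k <= 16):
   a sliding unaligned load from a 0x00/0xff table. */
static __m128i byte_mask_hi(int k)
{
    static const uint8_t t[32] = {
        0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
        0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,
        0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff,
        0xff,0xff,0xff,0xff,0xff,0xff,0xff,0xff
    };
    return _mm_loadu_si128((const __m128i *)(t + k));
}

/* Branchless (A & C) | B, for 0 <= n <= 128. */
static __m128i bit_mask_bytes(int n)
{
    __m128i A = _mm_set1_epi8((char)((0xff00 >> (n & 7)) & 0xff)); /* residual n%8 MS bits per byte */
    __m128i B = byte_mask_hi(n >> 3);                              /* n/8 whole bytes       */
    __m128i C = byte_mask_hi((n + 7) >> 3);                        /* (n+7)/8 whole bytes   */
    return _mm_or_si128(_mm_and_si128(A, C), B);
}
```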
The next two solutions are alternatives to Paul R's answer. These solutions are of interest when the masks are needed inside a performance-critical loop.
SSE2
__m128i bit_mask_v2(unsigned int n){                   /* Create an __m128i vector with the n most significant bits set to 1 */
    __m128i ones_hi  = _mm_set_epi64x(-1, 0);          /* Binary vector of bits 1...1 and 0...0 */
    __m128i ones_lo  = _mm_set_epi64x(0, -1);          /* Binary vector of bits 0...0 and 1...1 */
    __m128i cnst64   = _mm_set1_epi64x(64);
    __m128i cnst128  = _mm_set1_epi64x(128);
    __m128i shift    = _mm_cvtsi32_si128(n);           /* Move n to an SSE register */
    __m128i shift_hi = _mm_subs_epu16(cnst64, shift);  /* Subtract with saturation: max(64-n, 0) */
    __m128i shift_lo = _mm_subs_epu16(cnst128, shift); /* max(128-n, 0) */
    __m128i hi = _mm_sll_epi64(ones_hi, shift_hi);     /* Shift the hi bits by 64-n positions if 64-n >= 0, else no shift */
    __m128i lo = _mm_sll_epi64(ones_lo, shift_lo);     /* Shift the lo bits by 128-n positions if 128-n >= 0, else no shift */
    return _mm_or_si128(lo, hi);                       /* Merge hi and lo */
}
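To see why the saturating subtraction gives the right shift counts, here is a scalar model (my own sketch, not part of the answer): _mm_sll_epi64 produces zero for counts >= 64, so for n <= 64 the low lane is shifted out entirely, while for n > 64 the high-lane count saturates to 0 and the high lane stays all ones.

```c
#include <stdint.h>

/* Scalar model of bit_mask_v2: per-lane shift where a count >= 64
   yields zero, as _mm_sll_epi64 does. */
static uint64_t sll64(uint64_t x, unsigned c) { return c >= 64 ? 0 : x << c; }

static void model_v2(unsigned n, uint64_t *hi, uint64_t *lo)
{
    unsigned shift_hi = n > 64  ? 0 : 64  - n;  /* saturating 64 - n  */
    unsigned shift_lo = n > 128 ? 0 : 128 - n;  /* saturating 128 - n */
    *hi = sll64(~0ull, shift_hi);
    *lo = sll64(~0ull, shift_lo);
}
```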
SSSE3
The SSSE3 case is more interesting. The pshufb instruction is used as a small lookup table. It took me some time to figure out the right combination of the (saturated) arithmetic and the constants.
__m128i bit_mask_SSSE3(unsigned int n){    /* Create an __m128i vector with the n most significant bits set to 1 */
    __m128i sat_const = _mm_set_epi8(247,239,231,223, 215,207,199,191, 183,175,167,159, 151,143,135,127); /* Constant used in combination with saturating addition */
    __m128i sub_const = _mm_set1_epi8(248);
    __m128i pshub_lut = _mm_set_epi8(0,0,0,0, 0,0,0,0,
                                     0b11111111, 0b11111110, 0b11111100, 0b11111000,
                                     0b11110000, 0b11100000, 0b11000000, 0b10000000);
    __m128i shift_bc  = _mm_set1_epi8(n);                   /* Broadcast n to the 16 8-bit elements */
    __m128i shft_byte = _mm_adds_epu8(shift_bc, sat_const); /* The constants sat_const and sub_const are selected such that */
    __m128i shuf_indx = _mm_sub_epi8(shft_byte, sub_const); /* _mm_shuffle_epi8 can be used as a tiny lookup table          */
    return _mm_shuffle_epi8(pshub_lut, shuf_indx);          /* which finds the right bit pattern at the right position      */
}
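To unpack that arithmetic, here is a scalar model (my own sketch) of what happens to byte i (i = 0 being the least significant byte of the vector): the saturating add clamps at 255, the plain subtract wraps modulo 256, and pshufb returns 0 whenever the index byte has its most significant bit set, so each byte comes out as 0x00, 0xff, or the residual pattern from the LUT.

```c
#include <stdint.h>

/* Scalar model of bit_mask_SSSE3 for byte i (i = 0 is the least
   significant byte of the vector). */
static uint8_t byte_of_mask(unsigned n, int i)
{
    static const uint8_t lut[16] = {            /* mirrors pshub_lut */
        0x80, 0xc0, 0xe0, 0xf0, 0xf8, 0xfc, 0xfe, 0xff,
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
    };
    unsigned sat = n + 127 + 8 * (unsigned)i;   /* _mm_adds_epu8 ...         */
    if (sat > 255) sat = 255;                   /* ... saturates at 255      */
    uint8_t idx = (uint8_t)(sat - 248);         /* _mm_sub_epi8 wraps mod 256 */
    return (idx & 0x80) ? 0 : lut[idx & 0x0f];  /* pshufb: MSB set -> 0      */
}
```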
Functionality
For 1 <= n <= 128, the range specified by the OP, the functions bit_mask_Paul_R(n) (Paul R's answer), bit_mask_v2(n), and bit_mask_SSSE3(n) produce the same results:
bit_mask_Paul_R( 0) = FFFFFFFFFFFFFFFF 0000000000000000
bit_mask_Paul_R( 1) = 8000000000000000 0000000000000000
bit_mask_Paul_R( 2) = C000000000000000 0000000000000000
bit_mask_Paul_R( 3) = E000000000000000 0000000000000000
.....
bit_mask_Paul_R(126) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFC
bit_mask_Paul_R(127) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFE
bit_mask_Paul_R(128) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFF
bit_mask_v2( 0) = 0000000000000000 0000000000000000
bit_mask_v2( 1) = 8000000000000000 0000000000000000
bit_mask_v2( 2) = C000000000000000 0000000000000000
bit_mask_v2( 3) = E000000000000000 0000000000000000
.....
bit_mask_v2(126) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFC
bit_mask_v2(127) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFE
bit_mask_v2(128) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFF
bit_mask_SSSE3( 0) = 0000000000000000 0000000000000000
bit_mask_SSSE3( 1) = 8000000000000000 0000000000000000
bit_mask_SSSE3( 2) = C000000000000000 0000000000000000
bit_mask_SSSE3( 3) = E000000000000000 0000000000000000
.....
bit_mask_SSSE3(126) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFC
bit_mask_SSSE3(127) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFE
bit_mask_SSSE3(128) = FFFFFFFFFFFFFFFF FFFFFFFFFFFFFFFF
For n = 0 the most reasonable result is the zero vector, which is produced by bit_mask_v2(n) and bit_mask_SSSE3(n).
Performance
To get a rough impression of the performance of the different functions, the following piece of code is used:
uint64_t x[2];
__m128i sum = _mm_setzero_si128();
for (int i = 0; i < 1000000000; i = i + 1){
    sum = _mm_add_epi64(sum, bit_mask_Paul_R(i));   // or use one of the next lines instead
    // sum = _mm_add_epi64(sum, bit_mask_v2(i));
    // sum = _mm_add_epi64(sum, bit_mask_SSSE3(i));
}
_mm_storeu_si128((__m128i*)x, sum);
printf("sum = %016lX %016lX\n", x[1], x[0]);
The performance of the code depends slightly on the type of instruction encoding. The GCC options opts1 = -O3 -m64 -Wall -march=nehalem lead to non-VEX-encoded SSE instructions, while opts2 = -O3 -m64 -Wall -march=sandybridge compiles to VEX-encoded AVX-128 instructions. The results with GCC 5.4 are:
Cycles per iteration on Intel Skylake, estimated with: perf stat -d ./a.out
opts1 opts2
bit_mask_Paul_R 6.0 7.0
bit_mask_v2 3.8 3.3
bit_mask_SSSE3 3.0 3.0
In practice the performance will depend on the CPU type and the surrounding code. The performance of bit_mask_SSSE3 is limited by port 5 pressure: three instructions per iteration (one movd and the two pshufbs) are handled by port 5.
With AVX2, more efficient code is possible; see here.