[英]Bit vector operation with AVX2 and SSE2
我是 AVX2 和 SSE2 指令集的新手,我想了解更多關於如何使用這些指令集來加速位向量操作的信息。
到目前為止,我已經成功地使用它們通過雙/浮點操作對代碼進行了矢量化。
在此示例中,我有一個 C++ 代碼,該代碼在將位向量(使用 unsigned int)中的某個位設置或不設置為特定值之前檢查條件:
int process_bit_vetcor(unsigned int *bitVector, float *value, const float threshold, const unsigned int dim)
{
int sum = 0, cond = 0;
for (unsigned int i = 0; i < dim; i++) {
unsigned int *word = bitVector + i / 32;
unsigned int bitValue = ((unsigned int)0x80000000 >> (i & 0x1f));
cond = (value[i] <= threshold);
(*word) = (cond) ? (*word) | bitValue : (*word);
sum += cond;
}
return sum;
}
變量sum僅返回條件為 TRUE 的情況數。
我試圖用 SSE2 和 AVX2 重寫這個例程,但沒有成功...... :-(
是否可以使用 AVX2 和 SSE2 重寫這樣的 C++ 代碼? 對這種類型的位操作使用矢量化是否值得? 位向量可能包含數千位,因此我希望使用 SSE2 和 AVX2 來加速可能會很有趣。
提前致謝!
如果dim
是 8 的倍數,則以下內容應該有效(要處理余數,請在末尾添加一個簡單的循環)。 較小的 API 更改:
long
而不是unsigned int
作為循環索引(這有助於 clang 展開循環)bitvector
是小端(如評論中所建議) 在循環內部,按字節訪問bitVector
。 一次將 2 或 4 個movemask
和 bit-or 結果組合起來可能是值得的(可能取決於目標架構)。
要計算sum
,直接從cmp_ps
操作的結果計算 8 個部分和。 由於無論如何您都需要位掩碼,因此可能值得使用popcnt
(理想情況下,在將 2、4 或 8 個字節組合在一起之后——同樣,這可能取決於您的目標體系結構)。
int process_bit_vector(uint32_t *bitVector32, float *value,
const float threshold_float, const long dim) {
__m256i sum = _mm256_setzero_si256();
__m256 threshold_vector = _mm256_set1_ps(threshold_float);
uint8_t *bitVector8 = (uint8_t *)bitVector32;
for (long i = 0; i <= dim-8; i += 8) {
// compare next 8 values with threshold
// (use threshold as first operand to allow loading other operand from memory)
__m256 cmp_mask = _mm256_cmp_ps(threshold_vector, _mm256_loadu_ps(value + i), _CMP_GE_OQ);
// true values are `-1` when interpreted as integers, subtract those from `sum`
sum = _mm256_sub_epi32(sum, _mm256_castps_si256(cmp_mask));
// extract bitmask
int mask = _mm256_movemask_ps(cmp_mask);
// bitwise-or current mask with result bit-vector
*bitVector8++ |= mask;
}
// reduce 8 partial sums to a single sum and return
__m128i sum_reduced = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extracti128_si256(sum,1));
sum_reduced = _mm_add_epi32(sum_reduced, _mm_srli_si128(sum_reduced, 8));
sum_reduced = _mm_add_epi32(sum_reduced, _mm_srli_si128(sum_reduced, 4));
return _mm_cvtsi128_si32(sum_reduced);
}
Godbolt-Link: https://godbolt.org/z/ABwDPe
vpsubd ymm2, ymm0, ymm1; vmovdqa ymm0, ymm2;
vpsubd ymm2, ymm0, ymm1; vmovdqa ymm0, ymm2;
而不僅僅是vpsubd ymm0, ymm0, ymm1
。vcmpps
加入load
(並使用LE
而不是GE
比較)——如果您不關心 NaN 的處理方式,您可以使用_CMP_NLT_US
而不是_CMP_GE_OQ
。大端 output 的修訂版本(未經測試):
int process_bit_vector(uint32_t *bitVector32, float *value,
const float threshold_float, const long dim) {
int sum = 0;
__m256 threshold_vector = _mm256_set1_ps(threshold_float);
for (long i = 0; i <= dim-32; i += 32) {
// compare next 4x8 values with threshold
// (use threshold as first operand to allow loading other operand from memory)
__m256i cmp_maskA = _mm256_castps_si256(_mm256_cmp_ps(threshold_vector, _mm256_loadu_ps(value + i+ 0), _CMP_GE_OQ));
__m256i cmp_maskB = _mm256_castps_si256(_mm256_cmp_ps(threshold_vector, _mm256_loadu_ps(value + i+ 8), _CMP_GE_OQ));
__m256i cmp_maskC = _mm256_castps_si256(_mm256_cmp_ps(threshold_vector, _mm256_loadu_ps(value + i+16), _CMP_GE_OQ));
__m256i cmp_maskD = _mm256_castps_si256(_mm256_cmp_ps(threshold_vector, _mm256_loadu_ps(value + i+24), _CMP_GE_OQ));
__m256i cmp_mask = _mm256_packs_epi16(
_mm256_packs_epi16(cmp_maskA,cmp_maskB), // b7b7b6b6'b5b5b4b4'a7a7a6a6'a5a5a4a4 b3b3b2b2'b1b1b0b0'a3a3a2a2'a1a1a0a0
_mm256_packs_epi16(cmp_maskC,cmp_maskD) // d7d7d6d6'd5d5d4d4'c7c7c6c6'c5c5c4c4 d3d3d2d2'd1d1d0d0'c3c3c2c2'c1c1c0c0
); // cmp_mask = d7d6d5d4'c7c6c5c4'b7b6b5b4'a7a6a5a4 d3d2d1d0'c3c2c1c0'b3b2b1b0'a3a2a1a0
cmp_mask = _mm256_permute4x64_epi64(cmp_mask, 0x8d);
// cmp_mask = [b7b6b5b4'a7a6a5a4 b3b2b1b0'a3a2a1a0 d7d6d5d4'c7c6c5c4 d3d2d1d0'c3c2c1c0]
__m256i shuff_idx = _mm256_broadcastsi128_si256(_mm_set_epi64x(0x00010203'08090a0b,0x04050607'0c0d0e0f));
cmp_mask = _mm256_shuffle_epi8(cmp_mask, shuff_idx);
// extract bitmask
uint32_t mask = _mm256_movemask_epi8(cmp_mask);
sum += _mm_popcnt_u32 (mask);
// bitwise-or current mask with result bit-vector
*bitVector32++ |= mask;
}
return sum;
}
這個想法是在對其應用vpmovmskb
之前對字節進行洗牌。 對於 32 個輸入值,這需要 5 次 shuffle 操作(包括 3 個vpacksswb
),但總和的計算是使用popcnt
而不是 4 vpsubd
的。 vpermq
( _mm256_permute4x64_epi64
) 可以通過在比較它們之前策略性地將 128 位半加載到 256 位向量中來避免。 另一個想法(因為無論如何您都需要對最終結果進行洗牌)將部分結果混合在一起(這往往需要p5
或2*p015
在我檢查過的架構上,所以可能不值得)。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.