[英]How to specialize a template function for different data types in which the procedures are similar?
比如我想用AVX2實現一個矩陣乘法模板function。 (假設“矩陣”是一個實現良好的模板類)
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
if (typeid(T).name() == typeid(float).name()) {
//using __m256 to store float
//using __m256_load_ps __m256_mul_ps __m256_add_ps
} else if (typeid(T).name() == typeid(double).name()) {
//using __m256d to store double
//using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
} else {
//...
}
}
由於沒有數據類型的“變量”,程序無法確定它是否應該使用 __m256 或 __m256d 或其他任何東西,從而使代碼非常長且笨拙。 還有另一種方法可以避免這種情況嗎?
在 C++17 及更高版本中,您可以使用if constexpr
:
#include <type_traits>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
if constexpr (std::is_same_v<T, float>) {
//using __m256 to store float
//using __m256_load_ps __m256_mul_ps __m256_add_ps
} else if constexpr (std::is_same_v<T, double>) {
//using __m256d to store double
//using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
} else {
//...
}
}
否則,只需使用重載:
Matrix<float> matmul(const Matrix<float>& mat1, const Matrix<float>& mat2) {
//using __m256 to store float
//using __m256_load_ps __m256_mul_ps __m256_add_ps
}
Matrix<double> matmul(const Matrix<double>& mat1, const Matrix<double>& mat2) {
//using __m256d to store double
//using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
}
...
首先,您可以為函數_mm256_load_*
和_mm256_mul_*
等創建重載:
namespace avx {
inline __m256 mm256_load(float const* a) {
return _mm256_load_ps(a);
}
inline __m256d mm256_load(double const* a) {
return _mm256_load_pd(a);
}
inline __m256 mm256_mul(__m256 m1, __m256 m2) {
return _mm256_mul_ps(m1, m2);
}
inline __m256d mm256_mul(__m256d m1, __m256d m2) {
return _mm256_mul_pd(m1, m2);
}
// add more avx functions here
} // namespace avx
然后,您可以創建類型特征以為float
和double
提供正確的 AVX 類型:
#include <type_traits>
namespace avx {
template<class T> struct floatstore;
template<> struct floatstore<float> { using type = __m256; };
template<> struct floatstore<double> { using type = __m256d; };
template<class T>
using floatstore_t = typename floatstore<T>::type;
} // namespace avx
你的最終 function 然后可以使用上面的重載函數和類型特征,你不需要像你原來的那樣進行任何運行時檢查:
template<class T>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
T floats[256/(sizeof(T)*CHAR_BIT)] = ...; // T is float or double
avx::floatstore_t<T> a_variable; // __m256 or __m256d
// uses the proper overload for the type:
a_variable = avx::mm256_load(floats)
}
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.