簡體   English   中英

如何針對程序相似的不同數據類型特化模板function?

[英]How to specialize a template function for different data types in which the procedures are similar?

比如我想用AVX2實現一個矩陣乘法模板function。 (假設“矩陣”是一個實現良好的模板類)

Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    if (typeid(T).name() == typeid(float).name()) {
        //using __m256 to store float
        //using __m256_load_ps __m256_mul_ps __m256_add_ps
    } else if (typeid(T).name() == typeid(double).name()) {
        //using __m256d to store double
        //using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
    } else {
        //...
    }
}

由於沒有數據類型的“變量”,程序無法確定它是否應該使用 __m256 或 __m256d 或其他任何東西,從而使代碼非常長且笨拙。 還有另一種方法可以避免這種情況嗎?

在 C++17 及更高版本中,您可以使用if constexpr

#include <type_traits>

Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    if constexpr (std::is_same_v<T, float>) {
        //using __m256 to store float
        //using __m256_load_ps __m256_mul_ps __m256_add_ps
    } else if constexpr (std::is_same_v<T, double>) {
        //using __m256d to store double
        //using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
    } else {
        //...
    }
}

否則,只需使用重載:

Matrix<float> matmul(const Matrix<float>& mat1, const Matrix<float>& mat2) {
    //using __m256 to store float
    //using __m256_load_ps __m256_mul_ps __m256_add_ps
}

Matrix<double> matmul(const Matrix<double>& mat1, const Matrix<double>& mat2) {
    //using __m256d to store double
    //using __m256d_load_pd __m256d_mul_pd __m256d_add_pd
}

...

首先,您可以為函數_mm256_load_*_mm256_mul_*等創建重載:

namespace avx {
inline __m256 mm256_load(float const* a) {
    return _mm256_load_ps(a);
}
inline __m256d mm256_load(double const* a) {
    return _mm256_load_pd(a);
}

inline __m256 mm256_mul(__m256 m1, __m256 m2) {
    return _mm256_mul_ps(m1, m2);
}
inline __m256d mm256_mul(__m256d m1, __m256d m2) {
    return _mm256_mul_pd(m1, m2);
}

// add more avx functions here
} // namespace avx

然后,您可以創建類型特征以為floatdouble提供正確的 AVX 類型:

#include <type_traits>

namespace avx {

template<class T> struct floatstore;
template<> struct floatstore<float> { using type = __m256; };
template<> struct floatstore<double> { using type = __m256d; };

template<class T>
using floatstore_t = typename floatstore<T>::type;
} // namespace avx

你的最終 function 然后可以使用上面的重載函數和類型特征,你不需要像你原來的那樣進行任何運行時檢查:

template<class T>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    T floats[256/(sizeof(T)*CHAR_BIT)] = ...; // T is float or double
    avx::floatstore_t<T> a_variable;          // __m256 or __m256d

    // uses the proper overload for the type:
    a_variable = avx::mm256_load(floats)
}

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM