How to specialize a template function for different data types in which the procedures are similar?

For example, I want to implement a matrix multiplication template function using AVX2. (Suppose "Matrix" is a well-implemented template class.)

template <class T>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    if (typeid(T) == typeid(float)) {
        // use __m256 to store float
        // use _mm256_load_ps, _mm256_mul_ps, _mm256_add_ps
    } else if (typeid(T) == typeid(double)) {
        // use __m256d to store double
        // use _mm256_load_pd, _mm256_mul_pd, _mm256_add_pd
    } else {
        // ...
    }
}

Since there is no "variable" that holds a data type, the program cannot decide at runtime whether it should use __m256, __m256d, or something else, which makes the code long and awkward. Is there another way to avoid this?

In C++17 and later, you can use if constexpr:

#include <type_traits>

template <class T>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    if constexpr (std::is_same_v<T, float>) {
        // use __m256 to store float
        // use _mm256_load_ps, _mm256_mul_ps, _mm256_add_ps
    } else if constexpr (std::is_same_v<T, double>) {
        // use __m256d to store double
        // use _mm256_load_pd, _mm256_mul_pd, _mm256_add_pd
    } else {
        // ...
    }
}

Otherwise, just use overloads:

Matrix<float> matmul(const Matrix<float>& mat1, const Matrix<float>& mat2) {
    // use __m256 to store float
    // use _mm256_load_ps, _mm256_mul_ps, _mm256_add_ps
}

Matrix<double> matmul(const Matrix<double>& mat1, const Matrix<double>& mat2) {
    // use __m256d to store double
    // use _mm256_load_pd, _mm256_mul_pd, _mm256_add_pd
}

...

First, you could create overloads for the functions _mm256_load_* and _mm256_mul_* etc.:

#include <immintrin.h>

namespace avx {
inline __m256 mm256_load(float const* a) {
    return _mm256_load_ps(a);
}
inline __m256d mm256_load(double const* a) {
    return _mm256_load_pd(a);
}

inline __m256 mm256_mul(__m256 m1, __m256 m2) {
    return _mm256_mul_ps(m1, m2);
}
inline __m256d mm256_mul(__m256d m1, __m256d m2) {
    return _mm256_mul_pd(m1, m2);
}

// add more avx functions here
} // namespace avx

You could then create a type trait to give you the proper AVX type for float and double:

#include <type_traits>

namespace avx {

template<class T> struct floatstore;
template<> struct floatstore<float> { using type = __m256; };
template<> struct floatstore<double> { using type = __m256d; };

template<class T>
using floatstore_t = typename floatstore<T>::type;
} // namespace avx

Your final function could then use the overloaded functions and the type trait above, with no runtime type checks like those in your original code:

#include <climits> // CHAR_BIT

template <class T>
Matrix<T> matmul(const Matrix<T>& mat1, const Matrix<T>& mat2) {
    T floats[256 / (sizeof(T) * CHAR_BIT)] = ...; // T is float or double
    avx::floatstore_t<T> a_variable;              // __m256 or __m256d

    // uses the proper overload for the type:
    a_variable = avx::mm256_load(floats);
}
