
How do I detect if a variable is an Armadillo array?

Is there a way to find out whether a variable is an Armadillo array?

I need to implement a function of this kind:

template<typename T>
T foo(T)
{
    T res;

    if(is_armadillo(T))
    {
        ...
    }
    else
    {
        ...
    }

    return res;
}

You can build a template metafunction (a type trait) for this:

#include <type_traits>
#include <armadillo>

// primary template - false in all other cases
template< typename, typename = void >
struct is_armadillo
      : std::false_type {};

// specialization - matches if T has a t() (transpose) member function,
// as Armadillo's matrix and vector types (Mat, Col, Row, SpMat) do
template< typename T >
struct is_armadillo< T, 
     typename std::enable_if<std::is_member_function_pointer<decltype(&T::t)>::value>::type >
: std::true_type {};

You can use it to create an element-wise maximum function:

// case 1: parameters are both arithmetic types
template< typename T1, typename T2,
    typename std::enable_if<std::is_arithmetic<T2>::value>::type* = nullptr>
T2 maximum(T1 th, T2 v)
{
    static_assert( std::is_arithmetic<T1>::value, 
            "First argument is not arithmetic" );
    // use the built-in multiplication operator
    return v*(v>=th) + th*(v<th);
}

// case 2: parameter two is an armadillo array
template< typename T1, typename T2,
    typename std::enable_if<is_armadillo<T2>::value>::type* = nullptr>
T2 maximum(T1 th, T2 v)
{
    static_assert( std::is_arithmetic<T1>::value, 
            "First argument is not arithmetic" );

    // use the element-wise multiplication operator
    return v%(v>=th) + th*(v<th);
}

A simple test:

using namespace std;
using namespace arma;

int main()
{
    double a = -0.6;
    vec v{-0.1,0.9,0.3,-1.6};

    double th = 0;
    cout << endl;

    cout << "original value:            ";
    cout << a << endl;
    cout << "clipped to positive part:  ";
    cout << maximum(th,a) << endl;
    cout << endl;

    cout << endl;
    cout << "original array:            ";
    v.t().raw_print();
    cout << "clipped to positive parts: ";
    maximum(th,v).t().raw_print();
    cout << endl;

    return 0;
}

output:

original value:            -0.6
clipped to positive part:  0
original array:            -0.1 0.9 0.3 -1.6
clipped to positive parts: 0 0.9 0.3 0

Alternatively, you can print the variable's type name at run time:

#include <typeinfo>
...
cout << typeid(variable).name() << endl;

