
C++ struct functor as function template parameter

I am confused. I know that CUDA and other libraries allow a struct to be used as a functor through a template parameter, so I have designed a few of them for a neural network class:

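struct sigmoid
{
    sigmoid() = default;
    // Logistic function 1 / (1 + e^-x), built from CUDA's device-only fast intrinsics.
    __device__ float operator()(const float x) const
    {
        float exp_val = __expf(-x);              // fast approximate e^-x
        float denom   = __fadd_rz(1.f, exp_val); // 1 + e^-x, rounded toward zero
        return __fdividef(1.f, denom);           // fast approximate 1 / denom
    }
};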
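Because __expf, __fadd_rz and __fdividef exist only in device code, this functor cannot be unit-tested on the CPU as written. Here is a minimal sketch of a variant that compiles with both nvcc and an ordinary C++ compiler. It assumes the more accurate (and slower) expf is an acceptable stand-in for the intrinsics. HOST_DEVICE is a hypothetical macro name, not part of CUDA:

```cpp
#include <cassert>
#include <math.h>

// Hypothetical helper macro: under nvcc it makes operator() callable from both
// host and device code; under a plain C++ compiler it expands to nothing.
#ifdef __CUDACC__
#define HOST_DEVICE __host__ __device__
#else
#define HOST_DEVICE
#endif

// Logistic sigmoid 1 / (1 + e^-x) as a stateless function object.
struct sigmoid
{
    HOST_DEVICE float operator()(const float x) const
    {
        // expf is usable in host and device code, unlike the device-only
        // intrinsics __expf and __fdividef used in the original.
        return 1.0f / (1.0f + expf(-x));
    }
};
```

With this, `sigmoid()(0.0f)` can be checked on the host (it yields 0.5f) before the same struct is handed to a kernel.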

Using it with a CUDA kernel is fairly straightforward:

activate<sigmoid><<<num_blocks_x,block_threads_x>>>(sigmoid(),output_ptr);

where the kernel is:

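template <typename F>
__global__ void activate(F func, float * input)  // functor passed by value, as kernel arguments are copied at launch
{
   int x = blockIdx.x * blockDim.x + threadIdx.x;  // one element per thread
   input[x] = func(input[x]);                      // apply the activation in place
}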
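The kernel is an element-wise map: thread x replaces input[x] with func(input[x]). To follow the functor without a GPU, here is a host-only sketch in which a loop plays the role of the thread grid. The name activate_host, the explicit length n, and the stateful scale functor are illustrative assumptions, not part of the question's code:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stateful functor: the factor k travels with every copy.
struct scale
{
    float k;
    float operator()(const float x) const { return k * x; }
};

// Host stand-in for activate<F>: iteration x does the work of CUDA thread x.
// func is taken by value, the way a kernel receives copies of its arguments.
template <typename F>
void activate_host(F func, float* input, std::size_t n)
{
    assert(input != nullptr || n == 0);
    for (std::size_t x = 0; x < n; ++x)
        input[x] = func(input[x]);
}
```

Unlike the question's kernel, this loop stops at n. On the GPU the equivalent guard is `if (x < n)`, which matters whenever the element count is not a multiple of the block size.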

However, I want to wrap the kernel launch in a member function template that takes the functor as a parameter and then forwards it to the kernel:

template <class A>
thrust::host_vector<float> propagate(A func,
                                     thrust::device_vector<float> & input) const;

I've implemented it in a separate header, which is included at the end of the header that declares the class:

class ann
{
...
};
#include "ann_imp.hpp"

And the implementation header:

template <class A> inline
__host__ thrust::host_vector<float> ann::propagate(A func,
                                                   thrust::device_vector<float> & input) const
{
     activate<func><<<num_blocks_x,block_threads_x>>>(func(),output_ptr);
}

Yet when I call the propagate method, compilation fails:

net.propagate<sigmoid>( sigmoid(), in_vec1 );

This produces:

error: function "sigmoid::operator()" cannot be called with the given argument list
            object type is: sigmoid

When I pass only the type name instead of a constructed object:

xor_net.propagate<sigmoid>( sigmoid, in_vec1 );

I get:

error: type name is not allowed

Using an actual object yields the same error:

sigmoid func;
xor_net.propagate<sigmoid>( func, in_vec1 );

I've also tried declaring the parameter as A const& func and similar variations, but to no avail.

How do I pass a struct functor, and then forward it to the CUDA kernel?

EDIT: Without the wrapper, launching the activation kernel simply required:

activate<sigmoid><<<num_blocks_x,block_threads_x>>>(sigmoid(),output_ptr); 

Answer: In your functor you have:

 __device__ float operator()(const float x) const ...

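That operator needs exactly one argument of type float, but in ann::propagate you write func(), which calls sigmoid::operator() with an empty argument list; that is what the first error message reports. The launch line is: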

activate<func><<<num_blocks_x,block_threads_x>>>(func(),output_ptr);
                                                 ^^^^^^
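Two different things share the () syntax here. After a type name, as in sigmoid(), it constructs a temporary object. After an object, as in func(), it calls operator(). The second error, "type name is not allowed", has the mirror-image cause: a function argument must be an expression, so a bare sigmoid is rejected. The sketch below checks the first point at compile time. It is host-only, assumes C++17 for std::is_invocable, and uses expf in place of the device intrinsics; call_on_temporary is a hypothetical helper:

```cpp
#include <cassert>
#include <math.h>
#include <type_traits>

// Host version of the question's functor.
struct sigmoid
{
    float operator()(const float x) const { return 1.0f / (1.0f + expf(-x)); }
};

// func() with no arguments is what propagate wrote; no such overload exists.
static_assert(!std::is_invocable_v<const sigmoid&>,
              "operator() cannot be called with an empty argument list");

// func(x) with a single float is the only valid call.
static_assert(std::is_invocable_r_v<float, const sigmoid&, float>,
              "operator() takes one float and returns float");

// sigmoid() value-initialises a temporary; the temporary can then be called.
inline float call_on_temporary(float x) { return sigmoid()(x); }
```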

I believe that line needs to be:

activate<A><<<num_blocks_x,block_threads_x>>>(func,output_ptr);
       ^^^^                                   ^^^^^     
       Fix the type                           Use the object.
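Putting the fix together, here is a host-only sketch of the corrected wrapper: propagate receives the functor by value and forwards that same object, with the type A (not the object func) as the explicit template argument. The trimmed-down ann class, activate_host and std::vector are stand-ins for the real class, the CUDA kernel and the Thrust vectors; they are assumptions for illustration:

```cpp
#include <cassert>
#include <cstddef>
#include <math.h>
#include <vector>

// Host version of the question's functor (expf replaces the device intrinsics).
struct sigmoid
{
    float operator()(const float x) const { return 1.0f / (1.0f + expf(-x)); }
};

// Host stand-in for the activate<F> kernel: iteration x plays thread x.
template <typename F>
void activate_host(F func, float* input, std::size_t n)
{
    for (std::size_t x = 0; x < n; ++x)
        input[x] = func(input[x]);
}

// Hypothetical, trimmed-down ann: only the wrapper is modelled.
class ann
{
public:
    template <class A>
    std::vector<float> propagate(A func, std::vector<float>& input) const;
};

// Out-of-class definition, as it might appear in ann_imp.hpp.
template <class A>
inline std::vector<float> ann::propagate(A func, std::vector<float>& input) const
{
    // Template argument: the type A. Call argument: the object func.
    activate_host<A>(func, input.data(), input.size());
    return input; // copy of the results, standing in for the host_vector return
}
```

Because A appears in the parameter list, it is deduced from the argument, so net.propagate(sigmoid(), in_vec1) works without spelling out <sigmoid>.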

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM