繁体   English   中英

C ++结构函子作为函数模板参数

[英]C++ struct functor as function template parameter

我很困惑,我知道CUDA和其他库允许使用模板结构作为函子。 因此,我为神经网络类设计了其中一些:

struct sigmoid
{
     sigmoid()=default;                                                           
     __device__ float operator()(const float x) const                                                                                     
    {                                                                                                                                     
         float exp_val = __expf(-x);                                                                                                       
         float denom = __fadd_rz(1.f,exp_val);                                                                                             
         return __fdividef(1.f,denom);                                                
    }                                                                     
};       

当我将其用于CUDA内核时,其用法有些简单:

activate<sigmoid><<num_blocks_x,block_threads_x>>>(sigmoid(),output_ptr);

对于:

template <typename F>                                                                                                                     
__global__ void activate(F const& func, float * input)                                                                                    
{                                                                                                                                         
   int x = blockIdx.x * blockDim.x + threadIdx.x;                                                                                        
   input[x]  = func(input[x]);                                                                                                           
} 

但是我想将函数模板包装在调用CUDA内核的方法周围 ,然后将其转发给它:

template <class A>                                                                                                             
thrust::host_vector<float> propagate (                                                                                                
                                       A func,                                                                 
                                       thrust::device_vector<float> & input                                                          
                                     ) const; 

我已经将它实现到一个单独的标头中,该标头包含在声明该类的标头的末尾。

class ann
{
...
};
#include ann_imp.hpp

和imp标头:

template <class A> inline                                                                                                                   
__host__ thrust::host_vector<float> ann::propagate (                                                                                        
                                                       A func,                                                                            
                                                       thrust::device_vector<float> & input                                               
                                                    ) const                                                                                 
{
     activate<func><<<num_blocks_x,block_threads_x>>>(func(),output_ptr);
}  

但是,当我调用实际的propagate方法时,我遇到了麻烦:

net.propagate<sigmoid>( sigmoid(), in_vec1 );

产生:

error: function "sigmoid::operator()" cannot be called with the given argument list
            object type is: sigmoid

当我不使用operator()而是仅使用类型名时:

xor_net.propagate<sigmoid>( sigmoid, in_vec1 );

我得到:

error: type name is not allowed

使用实际对象会产生相同的错误:

sigmoid func;
xor_net.propagate<sigmoid>( func, in_vec1 );

我尝试使用参数为A const& func等,但无济于事。

如何传递结构函子,然后将其转发到CUDA内核?

编辑没有包装,调用激活功能根本要求:

activate<sigmoid><<<num_blocks_x,block_threads_x>>>(sigmoid(),output_ptr); 

你有:

 __device__ float operator()(const float x) const ...

该函数需要一个float类型的参数。 您从ann::propagate调用它为:

activate<func><<<num_blocks_x,block_threads_x>>>(func(),output_ptr);
                                                 ^^^^^^

我认为这行必须是:

activate<A><<<num_blocks_x,block_threads_x>>>(func,output_ptr);
       ^^^^                                   ^^^^^     
       Fix the type                           Use the object.

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM