I am confused. I know that CUDA and other libraries (such as Thrust) allow a struct to be passed as a functor through a template parameter, so I have designed a few of them for a neural network class:
struct sigmoid
{
sigmoid()=default;
__device__ float operator()(const float x) const
{
float exp_val = __expf(-x);
float denom = __fadd_rz(1.f,exp_val);
return __fdividef(1.f,denom);
}
};
When I launch a CUDA kernel with it, the usage is straightforward:
activate<sigmoid><<<num_blocks_x,block_threads_x>>>(sigmoid(),output_ptr);
with the kernel:
template <typename F>
__global__ void activate(F func, float * input) // functor taken by value: kernel arguments are copied to the device
{
int x = blockIdx.x * blockDim.x + threadIdx.x;
input[x] = func(input[x]);
}
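For reference, here is a self-contained sketch of the direct launch, assuming a thrust::device_vector<float> holds the data. The functor is taken by value, because kernel arguments are copied to the device and a reference parameter would refer to the caller's host-side object (it only happens to work for an empty struct like sigmoid). The extra n parameter, the block size of 256 and the run_sigmoid helper are hypothetical additions for illustration.

```cuda
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

// Same functor as above: an empty struct with a __device__ call operator.
struct sigmoid
{
    __device__ float operator()(const float x) const
    {
        return __fdividef(1.f, __fadd_rz(1.f, __expf(-x)));
    }
};

// Functor passed by value; n (hypothetical) guards threads past the end.
template <typename F>
__global__ void activate(F func, float* input, int n)
{
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    if (x < n)
        input[x] = func(input[x]);
}

// Hypothetical helper: apply sigmoid in place, then copy the results to the host.
thrust::host_vector<float> run_sigmoid(thrust::device_vector<float>& data)
{
    const int n = static_cast<int>(data.size());
    const int block_threads_x = 256;  // assumed block size
    const int num_blocks_x = (n + block_threads_x - 1) / block_threads_x;
    float* output_ptr = thrust::raw_pointer_cast(data.data());
    activate<sigmoid><<<num_blocks_x, block_threads_x>>>(sigmoid(), output_ptr, n);
    thrust::host_vector<float> result(data);  // device -> host copy
    return result;
}
```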
However, I want to make the class method that launches the CUDA kernel a template on the functor type, and forward the functor to the kernel:
template <class A>
thrust::host_vector<float> propagate (
A func,
thrust::device_vector<float> & input
) const;
I've implemented it in a separate header, which is included at the end of the header that declares the class:
class ann
{
...
};
#include "ann_imp.hpp"
And the implementation header:
template <class A> inline
__host__ thrust::host_vector<float> ann::propagate (
A func,
thrust::device_vector<float> & input
) const
{
activate<func><<<num_blocks_x,block_threads_x>>>(func(),output_ptr);
}
Yet when I call the actual propagate method, I run into trouble:
net.propagate<sigmoid>( sigmoid(), in_vec1 );
Produces:
error: function "sigmoid::operator()" cannot be called with the given argument list
object type is: sigmoid
When I drop the parentheses and pass only the type name:
xor_net.propagate<sigmoid>( sigmoid, in_vec1 );
I get:
error: type name is not allowed
Using an actual object yields the same error:
sigmoid func;
xor_net.propagate<sigmoid>( func, in_vec1 );
I've tried variations such as declaring the parameter as A const& func, but to no avail.
How do I pass a struct functor, and then forward it to the CUDA kernel?
EDIT: Without the wrapper, launching the activation kernel simply required:
activate<sigmoid><<<num_blocks_x,block_threads_x>>>(sigmoid(),output_ptr);
You have:
__device__ float operator()(const float x) const ...
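The function needs an argument of type float, so func() tries to call sigmoid::operator() with no arguments, which is exactly what the first error message reports. You are calling it from ann::propagate as:
activate<func><<<num_blocks_x,block_threads_x>>>(func(),output_ptr);
The same line has a second problem: func is an object, not a type, so it cannot be used as the template argument. I believe that line needs to be:
activate<A><<<num_blocks_x,block_threads_x>>>(func,output_ptr);
That is, fix the type (use the template parameter A) and use the object (pass func itself instead of calling it).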
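Putting the fix together, here is a self-contained sketch of the corrected member template. The ann class below is a stripped-down hypothetical stand-in for the real one; the block size of 256 and the extra n bounds parameter are assumptions; and output_ptr is assumed to be the raw pointer of input, since the original does not show where it comes from. The sketch also returns the results, because propagate is declared to return a thrust::host_vector<float>.

```cuda
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

struct sigmoid
{
    __device__ float operator()(const float x) const
    {
        return __fdividef(1.f, __fadd_rz(1.f, __expf(-x)));
    }
};

template <typename F>
__global__ void activate(F func, float* input, int n)
{
    int x = blockIdx.x * blockDim.x + threadIdx.x;
    if (x < n)
        input[x] = func(input[x]);
}

// Hypothetical, stripped-down stand-in for the real ann class.
class ann
{
public:
    template <class A>
    thrust::host_vector<float> propagate(A func,
                                         thrust::device_vector<float>& input) const;
};

// Would live in ann_imp.hpp, included after the class definition.
template <class A> inline
thrust::host_vector<float> ann::propagate(A func,
                                          thrust::device_vector<float>& input) const
{
    const int n = static_cast<int>(input.size());
    const int block_threads_x = 256;  // assumed launch configuration
    const int num_blocks_x = (n + block_threads_x - 1) / block_threads_x;
    float* output_ptr = thrust::raw_pointer_cast(input.data());  // assumed source of output_ptr
    // A is the type (template argument); func is the object (kernel argument).
    activate<A><<<num_blocks_x, block_threads_x>>>(func, output_ptr, n);
    thrust::host_vector<float> result(input);  // copy device results to host
    return result;
}
```

Because A appears in the parameter list, it can be deduced, so net.propagate(sigmoid(), in_vec1) works as well as the explicit net.propagate<sigmoid>(sigmoid(), in_vec1).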