I am performing a series of calculations on a large number of threads using C++ AMP. The last step of the calculation, though, is to prune the result, but only for a limited number of threads. For example, if the result of the calculation is below a threshold, set the result to 0, but only do this for a maximum of X threads. Essentially this needs a shared counter combined with a shared conditional check.
Any help is appreciated!
My understanding of your question is the following pseudo-code performed by each thread:
    auto result = ...
    if(result < global_threshold)           // if the result of the calculation is below a threshold
        if(global_counter++ < global_max)   // for a maximum of X threads
            result = 0;                     // then set the result to 0
    store(result);
I further assume that neither global_threshold nor global_max changes during the computation (i.e. between the start and finish of parallel_for_each), so the most elegant way to pass them is through lambda capture.
On the other hand, global_counter clearly changes value, so it must be located in modifiable memory shared across all threads, which in practice means an array<T,N> or an array_view<T,N>. Since the threads incrementing this object are not synchronized, the increment must be performed with an atomic operation.
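To see the fetch-and-increment idea in isolation, here is a minimal CPU-side sketch in standard C++, with std::atomic and std::thread standing in for the accelerator. The helper name prune_with_cap and its parameters are made up for illustration and are not part of C++ AMP; the point is only that the atomic increment returns the value from before the increment, so at most max_zeroed threads pass the inner check.

```cpp
#include <algorithm>
#include <atomic>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical CPU-side helper (not part of C++ AMP): zeroes at most
// `max_zeroed` of the values below `threshold` and returns how many values
// were below the threshold in total.
int prune_with_cap(std::vector<int>& results, int threshold, int max_zeroed,
                   unsigned num_threads = 4)
{
    num_threads = std::max(1u, num_threads);
    std::atomic<int> counter{0};  // plays the role of global_counter

    auto worker = [&](std::size_t begin, std::size_t end) {
        for (std::size_t i = begin; i < end; ++i) {
            if (results[i] < threshold) {
                // fetch_add returns the value *before* the increment, so
                // at most max_zeroed callers ever observe a value below it.
                if (counter.fetch_add(1) < max_zeroed) {
                    results[i] = 0;
                }
            }
        }
    };

    // Each thread owns a disjoint slice, so only the counter is shared.
    std::vector<std::thread> threads;
    const std::size_t chunk = (results.size() + num_threads - 1) / num_threads;
    for (unsigned t = 0; t < num_threads; ++t) {
        const std::size_t begin = std::min(results.size(), t * chunk);
        const std::size_t end = std::min(results.size(), begin + chunk);
        threads.emplace_back(worker, begin, end);
    }
    for (auto& th : threads) th.join();

    return counter.load();
}
```

Which elements end up zeroed depends on thread scheduling; only their number is bounded.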
The above translates to the following C++ AMP code (I'm using Visual Studio 2013 syntax, but it is easily back-portable to Visual Studio 2012):
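    #include <algorithm>
    #include <iostream>
    #include <vector>
    #include <amp.h>

    using namespace concurrency;

    int main()
    {
        std::vector<int> result_storage(1024);
        array_view<int> av_result{ result_storage };
        av_result.discard_data(); // every element is overwritten, so skip the copy-in

        // Shared counter, wrapped in an array_view so all threads can modify it
        int global_counter_storage[1] = { 0 };
        array_view<int> global_counter{ global_counter_storage };

        // Constants, captured by value in the lambda below
        int global_threshold = 42;
        int global_max = 3;

        parallel_for_each(av_result.extent, [=](index<1> idx) restrict(amp)
        {
            int result = (idx[0] % 50) + 1; // 1 .. 50
            if(result < global_threshold)
            {
                // assuming less than INT_MAX threads will enter here
                // atomic_fetch_inc returns the value before the increment
                if(atomic_fetch_inc(&global_counter[0]) < global_max)
                {
                    result = 0;
                }
            }
            av_result[idx] = result;
        });

        av_result.synchronize(); // copy the results back into result_storage
        auto zeros = std::count(result_storage.begin(), result_storage.end(), 0);
        std::cout << "Total number of zeros in results: " << zeros << std::endl
                  << "Total number of threads lower than threshold: " << global_counter[0]
                  << std::endl;
    }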
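If the "less than INT_MAX threads" assumption in the comment above is a concern, one possible variation is to let the counter saturate at the maximum by using a compare-exchange loop instead of an unconditional increment. The sketch below shows the idea in standard C++ with std::atomic; the helper name try_claim_slot is hypothetical. C++ AMP offers atomic_compare_exchange for int, which should allow a similar loop inside the kernel, but I have not tested that here. The trade-off is that the counter no longer reports how many threads were below the threshold.

```cpp
#include <atomic>

// Hypothetical helper: tries to claim one of `max_slots` slots.
// Returns true if a slot was claimed; the counter never exceeds max_slots,
// so it cannot overflow no matter how many callers there are.
bool try_claim_slot(std::atomic<int>& counter, int max_slots)
{
    int seen = counter.load();
    while (seen < max_slots) {
        // On failure, compare_exchange_weak stores the current counter
        // value into `seen`, and the loop condition is checked again.
        if (counter.compare_exchange_weak(seen, seen + 1)) {
            return true;
        }
    }
    return false;
}
```

A thread would then zero its result only when try_claim_slot returns true.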