
Understanding binary predicates in Thrust

In response to my previous question, someone gave me the following code:

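#include <thrust/device_vector.h>
#include <thrust/extrema.h>
#include <thrust/iterator/zip_iterator.h>
#include <thrust/tuple.h>

// "Less than" comparator on the (flag, value) pairs produced by a zip_iterator.
// Dereferencing the zip_iterator yields a tuple of elements, not a tuple of
// iterators, so the operator is templated on whatever tuple type Thrust passes.
struct Predicate
{
  template <typename Tuple>
  __host__ __device__ bool operator()(const Tuple& lhs, const Tuple& rhs) const
  {
    using thrust::get;
    // Both values are included: compare the values themselves.
    if (get<0>(lhs) && get<0>(rhs)) return get<1>(lhs) <= get<1>(rhs);
    // At least one value is excluded: lhs excluded -> true, else rhs excluded -> false.
    return !get<0>(lhs);
  }
};

typedef thrust::device_vector<bool>::iterator   BoolIterator;
typedef thrust::device_vector<float>::iterator  ValueIterator;
typedef thrust::tuple<BoolIterator, ValueIterator> IteratorTuple;
typedef thrust::zip_iterator<IteratorTuple> ZipIterator;

// bools[i] == true means values[i] takes part in the search.
ZipIterator find_max(thrust::device_vector<bool>& bools,
                     thrust::device_vector<float>& values)
{
  ZipIterator iter_begin(thrust::make_tuple(bools.begin(), values.begin()));
  ZipIterator iter_end(thrust::make_tuple(bools.end(), values.end()));
  return thrust::max_element(iter_begin, iter_end, Predicate());
}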

I want to understand the Predicate struct. What happens if the operator returns false? Which value gets selected? What happens if the operator returns true? Which value gets selected?

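I tried to implement a 'less than' predicate: it returns true if lhs <= rhs and false otherwise. Additionally, you asked to exclude values using boolean flags stored in a second array, so it checks those flags as well. thrust::max_element treats the predicate as "is lhs less than rhs?": when it returns true, rhs is considered the larger element, so lhs cannot be the result; when it returns false, lhs is not less than rhs and remains a candidate.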

From my comment:

I guess I over-optimized the code. This is a 'less than' predicate. If the if condition evaluates to false, one or both bool flags are false, so we need to exclude the corresponding value. We therefore check whether the lhs argument should be excluded (thrust::get<0>(lhs) == false); if it should, the predicate returns true, meaning 'lhs is less than rhs'. If thrust::get<0>(lhs) == true, then the rhs component is the one to exclude, and the predicate returns false, meaning 'lhs is not less than rhs'.

The return !get<0>(lhs); line is a collapsed version of the following code:

using thrust::get;
if (get<0>(lhs) && get<0>(rhs))
  return get<1>(lhs) <= get<1>(rhs);
// Otherwise we need to check which value should be excluded from the search.
if (get<0>(lhs) == false) // lhs should be excluded, so lhs is lesser,
                          // OR both should be excluded: whatever we return,
                          // that element will be eliminated in another comparison
  return true;
// Here get<0>(rhs) == false: rhs should be excluded, so rhs is lesser.
return false;
