
CUDA Thrust: get the previous element in a segment

I have a vector of values and a vector of keys (each key identifies a segment):

v = [1, 1, 1, 2, 3, 5, 6]
k = [0, 0, 1, 1, 0, 2, 2]

For each element, I want to find the previous element in the same segment. It can be either the value or its index in the original vector; either works.

So the result, expressed as values, should be:

r = [nan, 1, nan, 1, 1, nan, 5]

Any placeholder can be used instead of nan; it does not matter for the rest of the algorithm.
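To pin down the expected result, here is a minimal sequential sketch in plain C++ (no Thrust). The function name `previous_in_segment` and the `-1` sentinel standing in for nan are assumptions made for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Hypothetical reference helper (not part of Thrust): for each element, return
// the value of the nearest earlier element with the same key, or `sentinel`
// if the element is the first one of its segment.
std::vector<int> previous_in_segment(const std::vector<int>& values,
                                     const std::vector<int>& keys,
                                     int sentinel = -1)
{
    assert(values.size() == keys.size());

    std::vector<int> result(values.size(), sentinel);
    std::unordered_map<int, std::size_t> last_seen;  // key -> index of its latest element

    for (std::size_t i = 0; i < values.size(); ++i) {
        auto it = last_seen.find(keys[i]);
        if (it != last_seen.end()) {
            result[i] = values[it->second];  // previous element of the same segment
        }
        last_seen[keys[i]] = i;
    }
    return result;
}
```

Called with the v and k shown above, this returns {-1, 1, -1, 1, 1, -1, 5}, which is r with -1 in place of nan.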

I can probably achieve this with an exclusive segmented scan using max instead of sum. So, two questions:

  1. Is my approach correct?
  2. Is there a more elegant or efficient solution?
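As a hedged illustration of question 1, the sketch below models an exclusive segmented max-scan sequentially in plain C++. It assumes that equal keys are already contiguous and that -1 is smaller than every input; the function name is hypothetical. A max-scan over the values yields the running maximum of the segment, which equals the previous value only when values never decrease within a segment. Scanning the element indices instead does yield the previous index, because indices always increase.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Sequential model of an exclusive segmented max-scan.
// Assumption: elements with equal keys are contiguous; `init` seeds every segment.
std::vector<int> exclusive_segmented_max_scan(const std::vector<int>& input,
                                              const std::vector<int>& keys,
                                              int init = -1)
{
    assert(input.size() == keys.size());

    std::vector<int> out(input.size(), init);
    int running = init;
    for (std::size_t i = 0; i < input.size(); ++i) {
        if (i == 0 || keys[i] != keys[i - 1]) {
            running = init;                     // a new segment starts here
        }
        out[i] = running;                       // exclusive: current element not included
        running = std::max(running, input[i]);
    }
    return out;
}
```

For a single segment with values {3, 1, 2}, the scan over values gives {-1, 3, 3}, while the previous elements are {-1, 3, 1}. The scan over the indices {0, 1, 2} gives {-1, 0, 1}, which are the previous indices.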

The desired functionality can be implemented with the following steps:

  1. Sort v by k so that elements with equal keys end up next to each other. This has to be done with stable_sort_by_key: since you want to retrieve the "previous" element, the relative order of elements with equal keys must be preserved.

  2. Apply the following transformation to the sorted data:

if (previous element has the same key) then return value of previous element else return -1
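The two steps can be sketched on the host with the standard library alone, which may help when checking a GPU version. This is a minimal sketch, not the Thrust implementation; the function name and the -1 sentinel are assumptions.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical host-side model of the two steps. Results are in sorted order.
std::vector<int> prev_after_stable_sort(const std::vector<int>& values,
                                        const std::vector<int>& keys,
                                        int sentinel = -1)
{
    assert(values.size() == keys.size());

    // Step 1: stable sort of (key, value) pairs, comparing keys only, so that
    // elements with equal keys keep their original relative order.
    std::vector<std::pair<int, int>> kv;
    kv.reserve(keys.size());
    for (std::size_t i = 0; i < keys.size(); ++i) {
        kv.emplace_back(keys[i], values[i]);
    }
    std::stable_sort(kv.begin(), kv.end(),
                     [](const std::pair<int, int>& a, const std::pair<int, int>& b) {
                         return a.first < b.first;
                     });

    // Step 2: take the left neighbour's value when it has the same key.
    std::vector<int> result(kv.size(), sentinel);
    for (std::size_t i = 1; i < kv.size(); ++i) {
        if (kv[i].first == kv[i - 1].first) {
            result[i] = kv[i - 1].second;
        }
    }
    return result;
}
```

With v = {1, 1, 1, 2, 3, 5, 6} and k = {0, 0, 1, 1, 0, 2, 2}, this returns {-1, 1, 1, -1, 1, -1, 5}. Note that the result follows the sorted order of the keys, not the original order of v.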


The following code implements those steps:

#include <cstdint>
#include <iostream>
#include <iterator>   // std::ostream_iterator

#include <thrust/host_vector.h>
#include <thrust/device_vector.h>
#include <thrust/copy.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <thrust/tuple.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/zip_iterator.h>

#define PRINTER(name) print(#name, (name))
template <template <typename...> class V, typename T, typename ...Args>
void print(const char* name, const V<T,Args...> & v)
{
    std::cout << name << ":\t";
    thrust::copy(v.begin(), v.end(), std::ostream_iterator<T>(std::cout, "\t"));
    std::cout << std::endl;
}
template<typename... Iterators>
__host__ __device__
thrust::zip_iterator<thrust::tuple<Iterators...>> zip(Iterators... its)
{
    return thrust::make_zip_iterator(thrust::make_tuple(its...));
}

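// Functor: for the element at position `index` in the sorted sequence, returns
// the value of its left neighbour if both share the same key, otherwise -1.
template <typename IteratorType, typename Integer>
struct prev_value
{
    // `first` points to the beginning of the sorted key sequence.
    prev_value(IteratorType first) : first(first){}

    // Tuple layout: <0> index of the current element (>= 1),
    //               <1> value of the element at index-1.
    template <typename Tuple>
    __host__ __device__
    Integer operator()(const Tuple& t)
    {
        const auto& index = thrust::get<0>(t);
        const auto& previousValue = thrust::get<1>(t);

        Integer result = -1;  // sentinel: no previous element in this segment
        const auto& currentKey = *(first+index);
        const auto& previousKey = *(first+index-1);
        if(currentKey == previousKey)
        {
            result = previousValue;
        }

        return result;
    }

    IteratorType first;
};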

template <typename Integer, typename IteratorType>
prev_value<IteratorType, Integer> make_prev_value(IteratorType first)
{
  return prev_value<IteratorType, Integer>(first);
}


int main(int argc, char** argv)
{
    using Integer = std::int32_t;
    using HostVec = thrust::host_vector<Integer>;
    using DeviceVec = thrust::device_vector<Integer>;

    Integer v[] = {1, 1, 1, 2, 3, 5, 6};
    Integer k[] = {0, 0, 1, 1, 0, 2, 2};

    Integer size = sizeof(k)/sizeof(k[0]);

    HostVec h_k(k, k+size);
    HostVec h_v(v, v+size);

    // copy data to device
    DeviceVec d_k = h_k;
    DeviceVec d_v = h_v;

    std::cout << "---- input data ----" << std::endl;
    PRINTER(d_k);    
    PRINTER(d_v);

    thrust::stable_sort_by_key(d_k.begin(), d_k.end(), d_v.begin());
    std::cout << "---- after sorting ----" << std::endl;
    PRINTER(d_k);    
    PRINTER(d_v);

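    // d_r[0] keeps the sentinel -1: the first sorted element has no predecessor.
    DeviceVec d_r(size, -1);
    auto op = make_prev_value<Integer>(d_k.begin());
    // Pair index i (1..size-1) with d_v[i-1], i.e. the value of the previous element.
    thrust::transform(zip(thrust::make_counting_iterator(Integer(1)), d_v.begin()),
                      zip(thrust::make_counting_iterator(size), d_v.end()-1),
                      d_r.begin()+1,
                      op);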
    std::cout << "---- result ----" << std::endl;
    PRINTER(d_r);

    return 0;
}

Output:

---- input data ----
d_k:    0   0   1   1   0   2   2   
d_v:    1   1   1   2   3   5   6   
---- after sorting ----
d_k:    0   0   0   1   1   2   2   
d_v:    1   1   3   1   2   5   6   
---- result ----
d_r:    -1  1   1   -1  1   -1  5
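The result above is laid out in sorted-key order. If the result is needed in the original element order, one option is to sort a permutation of indices instead of the values and write each result back to its original position. The host-only sketch below illustrates the idea with the standard library; the function name and the -1 sentinel are assumptions. A Thrust version would presumably carry a sequence of original indices through the sort and then scatter, which is not shown here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper: previous value in the same segment, written back to the
// position each element had in the original (unsorted) input.
std::vector<int> prev_in_original_order(const std::vector<int>& values,
                                        const std::vector<int>& keys,
                                        int sentinel = -1)
{
    assert(values.size() == keys.size());

    // order[j] = original index of the j-th element after a stable sort by key.
    std::vector<std::size_t> order(keys.size());
    std::iota(order.begin(), order.end(), std::size_t{0});
    std::stable_sort(order.begin(), order.end(),
                     [&keys](std::size_t a, std::size_t b) { return keys[a] < keys[b]; });

    std::vector<int> result(values.size(), sentinel);
    for (std::size_t j = 1; j < order.size(); ++j) {
        const std::size_t cur = order[j];
        const std::size_t prev = order[j - 1];
        if (keys[cur] == keys[prev]) {
            result[cur] = values[prev];  // scatter to the original position
        }
    }
    return result;
}
```

For the input used in this thread, this returns {-1, 1, -1, 1, 1, -1, 5}, which matches the r given in the question with -1 in place of nan.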
