简体   繁体   English

如何投射二维推力:: device_vector <thrust::device_vector<int> &gt;指向原始指针

[英]how to cast a 2-dimensional thrust::device_vector<thrust::device_vector<int>> to raw pointer

When I use the thrust::device_vector in main function,I can pass it to the kernel function correctly,the code as follows: 在主函数中使用推力:: device_vector时,可以正确地将其传递给内核函数,代码如下:

 thrust::device_vector<int> device_a(2);
 thrust::host_vector<int> host_a(2);
 MyTest << <1, 2 >> >(thrust::raw_pointer_cast(&device_a[0]),device_a.size());
 host_a = device_a;
 for (int i = 0; i < host_a.size();i++)
 cout << host_a[i] << endl;

but I want to use 2-dimension device_vector in my code,How can I use it? 但我想在代码中使用二维device_vector,如何使用? as shown i the following code 如我所示以下代码

__global__ void MyTest(thrust::device_vector<int>* a,int total){  
    int idx = threadIdx.x;  
    if (idx < total){
        int temp = idx;  
        a[idx][0] = temp;  
        a[idx][1] = temp; 
        __syncthreads();
      }

}  
 void main(){
    thrust::device_vector<thrust::device_vector<int>> device_a(2,thrust::device_vector<int>(2));

    thrust::host_vector<thrust::host_vector<int>> host_a(2,thrust::host_vector<int>(2));

    MyTest << <1, 2 >> >(thrust::raw_pointer_cast(device_a.data()),device_a.size());
    host_a = device_a;
    for (int i = 0; i < host_a.size(); i++){
    cout << host_a[i][0] << endl;
    cout << host_a[i][1] << endl;
}
}

Generally, Thrust containers are host only types that can not be used in __device__ and __global__ functions. 通常,Thrust容器是仅主机类型,不能在__device____global__函数中使用。

The common way to use 2-D array is to put it in a 1-D linear memory space like the following code. 使用二维数组的常见方法是将其放入一维线性存储空间,如以下代码所示。

__global__ void MyTest(int* a, int nrows, int ncols) {
  int j = threadIdx.x;
  int i = threadIdx.y;
  if (i < nrows && j < ncols) {
    int temp = i + j;
    a[i * ncols + j] = temp;
  }

}

int main(int argc, char** argv) {
  int nrows = 2;
  int ncols = 2;
  thrust::device_vector<int> device_a(nrows * ncols);
  MyTest<<<1, dim3(2, 2)>>>(thrust::raw_pointer_cast(device_a.data()), rows, ncols);
  return 0;
}

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM