[英]Cuda to make Matrix Multiplication
使用cuda进行矩阵乘法时遇到问题。 我必须做A * A * A * A并将其保存在hB中。 使用Cublas可以,但是我不能使用CUDA。 Dimension可能是一个很高的值,例如2000。这是我的代码:
__global__ void CudaMM(float *A, float *B, int N)
{
int row = blockIdx.y*blockDim.y + threadIdx.y;
int col = blockIdx.x*blockDim.x + threadIdx.x;
float sum = 0.f;
for (int n = 0; n < N; ++n)
sum += A[row*N+n]*A[n*N+col];
B[row*N+col] = sum;
}
void CudaMult(int dimension,float *hMatrice,float *hB,float *d_A,float *d_B){
int N,K;
K = 100;
N = K*BLOCK_SIZE;
dim3 threadBlock(BLOCK_SIZE,BLOCK_SIZE);
dim3 grid(K,K);
cudaMemcpy(d_A,hMatrice,dimension*dimension*sizeof(float),cudaMemcpyHostToDevice);
CudaMM<<<grid,threadBlock>>>(d_A,d_B,N);
cudaMemcpy(hB,d_B,dimension*dimension*sizeof(float),cudaMemcpyDeviceToHost);
}
void CublasFindConnect(int dimension,float* mat,float* B){
float *d_A,*d_B;
cudaMalloc(&d_A,dimension*dimension*sizeof(float));
cudaMalloc(&d_B,dimension*dimension*sizeof(float));
int w=0;
while(w<5){
CudaMult(dimension,mat,B,d_A,d_B);
// Copy Matrix computed B to previous M
for (m=0; m<dimension; m++) {
for (n=0; n<dimension; n++) {
mat[m*dimension+n]=B[m*dimension+n];
B[m*dimension+n]=0;
}
}
w++;
}
cudaFree(d_A);
cudaFree(d_B);
}
我安装了最新的CUDA 6,它不需要cudaMemCpy,因为共享内存。
BLOCK_SIZE
? 这个想法不是要告诉我BLOCK_SIZE
是什么,而是要显示完整的代码。 cudaMallocManaged()
,您在CUDA 6中引用的功能具有您没有满足的特定要求(例如使用cudaMallocManaged()
),但是您的代码并不依赖于统一内存,因此无关紧要。 我在您的代码中看到的一个问题是您的dimension
变量是任意的(您可以说它最多可以是2000),但是您的计算大小固定为N=K*BLOCK_SIZE;
。 大概,如果您的BLOCK_SIZE是某个值,例如16或32,那么它将满足您大约2000的最大dimension
。
出现问题是因为您的网格大小可能大于有效数组的大小。 您正在启动N
x N
网格,但是N
可以大于dimension
。 这意味着某些启动的线程可以尝试在其有效维之外访问矩阵( A
和B
)。
您可以通过内核中的“线程检查”来解决此问题,如下所示:
__global__ void CudaMM(float *A, float *B, int N)
{
int row = blockIdx.y*blockDim.y + threadIdx.y;
int col = blockIdx.x*blockDim.x + threadIdx.x;
if ((row < N) && (col < N)) {
float sum = 0.f;
for (int n = 0; n < N; ++n)
sum += A[row*N+n]*A[n*N+col];
B[row*N+col] = sum;
}
}
并且您需要将内核调用修改为:
CudaMM<<<grid,threadBlock>>>(d_A,d_B,dimension);
您可能还需要考虑根据实际dimension
选择网格大小,而不是固定为100*BLOCK_SIZE
,但这对于使代码正常工作不是必需的。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.