繁体   English   中英

如何在PyOpenCL中覆盖数组元素

[英]How to overwrite array elements in PyOpenCL

我想用另一个数组覆盖PyOpenCL数组的一部分。 比方说

import numpy as np, pyopencl.array as cla
a = cla.zeros(queue,(3,3),'int8')
b = cla.ones(queue,(2,2),'int8')

现在我想做类似a[0:2,0:2] = b事情,并希望得到

1 1 0
1 1 0
0 0 0

出于速度原因,如何在不将所有内容复制到主机的情况下执行此操作?

Pyopencl数组能够做到这一点- 在回答这个问题的时候非常有限 -使用numpy语法(即确切地说,它是如何编写的)的局限性是:您只能沿第一个轴使用切片。

import numpy as np, pyopencl.array as cla

a = cla.zeros(queue,(3,3),'int8')
b = cla.ones(queue,(2,3),'int8')
# note b is 2x3 here 
a[0:2]=b #<-works
a[0:2,0:2]=b[:,0:2] #<-Throws an error about non-contiguity

因此, a[0:2,0:2] = b无效,因为目标切片数组具有不连续的数据。

我知道的唯一解决方案( 因为pyopencl.array类中的任何内容都无法使用切片数组/非连续数据 ),是编写自己的openCL内核来“手动”进行复制。

这是我用来在所有dtype的1D或2D pyopencl数组上进行复制的一段代码:

import numpy as np, pyopencl as cl, pyopencl.array as cla
ctx = cl.create_some_context()
queue = cl.CommandQueue(ctx)
kernel = cl.Program(ctx, """__kernel void copy(
            __global char *dest,      const int offsetd, const int stridexd, const int strideyd,
            __global const char *src, const int offsets, const int stridexs, const int strideys,
            const int word_size) {

            int write_idx = offsetd + get_global_id(0) + get_global_id(1) * stridexd + get_global_id(2) * strideyd ;
            int read_idx  = offsets + get_global_id(0) + get_global_id(1) * stridexs + get_global_id(2) * strideys;
            dest[write_idx] =  src[read_idx];

            }""").build()

def copy(dest,src):
    assert dest.dtype == src.dtype
    assert dest.shape == src.shape
    if len(dest.shape) == 1 :
        dest.shape=(dest.shape[0],1)
        src.shape=(src.shape[0],1)
        dest.strides=(dest.strides[0],0)
        src.strides=(src.strides[0],0)
    kernel.copy(queue, (src.dtype.itemsize,src.shape[0],src.shape[1]), None, dest.base_data, np.uint32(dest.offset), np.uint32(dest.strides[0]), np.uint32(dest.strides[1]), src.base_data, np.uint32(src.offset), np.uint32(src.strides[0]), np.uint32(src.strides[1]), np.uint32(src.dtype.itemsize))


a = cla.zeros(queue,(3,3),'int8')
b = cla.ones(queue,(2,2),'int8')

copy(a[0:2,0:2],b)
print(a)

在pyopencl邮件列表中,AndreasKlöckner给了我一个提示: pyopencl.array有一个未记录的函数,称为multiput() 语法如下:

cla.multi_put([arr],indices,out=[out])

“ arr”是源数组,“ out”是目标数组,“ indices”是int的一维数组(也在设备上),其中包含按行为主的线性元素索引。

例如,在我的第一篇文章中,将“ b”放入“ a”的索引为(0,1,3,4)。 您只需要以某种方式将索引放在一起即可使用multiput()而不是编写内核。 len(indices)当然必须等于b.size 还有一个take()multitake()函数,用于从数组中读取元素。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM