Vectorizing the conversion of columns in a 2D numpy array to byte strings
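Background
I have a 2D numpy array representing a large number of grid coordinate vectors. Each coordinate vector needs to be converted to a byte string so that the vectors can be put into a Python set.
This byte-string conversion is the real bottleneck in my code's run time, so I am looking for ways to speed it up.
Example code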
from numpy import int16
from numpy.random import randint
# make an array of coordinate vectors full of random ints
A = randint(-100,100,size = (10000,5), dtype=int16)
# pull each vector out of the array using iteration and convert to byte string
A = [v.tobytes() for v in A]
# build a set using the byte strings
S = set(A)
Timing tests
Testing the current code with timeit, we get:
from timeit import timeit
setup = 'from numpy import int16; from numpy.random import randint; A = randint(-100,100,size = (10000,5), dtype=int16)'
code = 'S = set([v.tobytes() for v in A])'
t = timeit(code, setup = setup, number=500)
print(t)
>>> 1.136594653999964
Building the set after the conversion accounts for less than 15% of the total run time:
setup = 'from numpy import int16; from numpy.random import randint; A = randint(-100,100,size = (10000,5), dtype=int16); A = [v.tobytes() for v in A]'
code = 'S = set(A)'
t = timeit(code, setup = setup, number=500)
print(t)
>>> 0.15499859599980482
It is also worth noting that doubling the integer size to 32 bits has very little effect on the run time:
setup = 'from numpy import int32; from numpy.random import randint; A = randint(-100,100,size = (10000,5), dtype=int32)'
code = 'S = set([v.tobytes() for v in A])'
t = timeit(code, setup = setup, number=500)
print(t)
>>> 1.1422132620000411
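This makes me suspect that most of the time here is consumed by the iteration itself, or by the overhead of the function calls to tobytes().
If that is the case, I am wondering whether there is a vectorized approach that avoids the iteration?
Thanks!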
Here's a vectorized approach using np.frombuffer -
# a : Input array of coordinates with int16 dtype
S = set(np.frombuffer(a,dtype='S'+str(a.shape[1]*2)))
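The *2 in the dtype string above is the byte width of one int16 element, so the expression as written only fits 16-bit data. A minimal sketch of how it might be generalized to other integer widths follows; the helper name rows_to_byte_set is hypothetical (not part of numpy), and it assumes a 2D input with at least one row.

```python
import numpy as np

def rows_to_byte_set(a):
    """Hypothetical helper: build a set with one fixed-width byte string per row."""
    # frombuffer reads the raw memory, so rows must be laid out consecutively
    a = np.ascontiguousarray(a)
    # bytes per row = number of columns * bytes per element
    row_bytes = a.shape[1] * a.dtype.itemsize
    return set(np.frombuffer(a, dtype=f'S{row_bytes}'))

# the first and third rows are duplicates, so two distinct rows remain
a16 = np.array([[1, 2, 3], [4, 5, 6], [1, 2, 3]], dtype=np.int16)
a32 = a16.astype(np.int32)

print(len(rows_to_byte_set(a16)))
print(len(rows_to_byte_set(a32)))
```

Deriving the width from a.dtype.itemsize keeps the row length correct if the array is later switched to int32, as in the question's timing test.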
Timings on the given sample dataset -
In [83]: np.random.seed(0)
...: a = randint(-100,100,size = (10000,5), dtype=int16)
In [128]: %timeit set([v.tobytes() for v in a])
2.71 ms ± 99.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [129]: %timeit set(np.frombuffer(a,dtype='S'+str(a.shape[1]*2)))
933 µs ± 4.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [130]: out1 = set([v.tobytes() for v in a])
In [131]: out2 = set(np.frombuffer(a,dtype='S'+str(a.shape[1]*2)))
In [132]: (np.sort(list(out1))==np.sort(list(out2))).all()
Out[132]: True
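One caveat worth checking before relying on the set contents: numpy's S dtype drops trailing null bytes when it hands back individual elements, so the entries produced by the frombuffer approach can be shorter than the full row width and are not always byte-for-byte identical to v.tobytes(). Since every row has the same fixed width, distinct rows should still map to distinct entries. The sketch below, under the assumption of the same kind of random int16 data as in the question, compares the two sets and pads the entries back to full width.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(-100, 100, size=(10000, 5), dtype=np.int16)

row_bytes = a.shape[1] * a.dtype.itemsize

# reference: one full-width bytes object per row
loop_set = {v.tobytes() for v in a}

# vectorized: entries may have trailing b'\x00' bytes removed
vec_set = set(np.frombuffer(a, dtype=f'S{row_bytes}'))

# padding each entry back to the fixed width recovers the original bytes
padded = {bytes(s).ljust(row_bytes, b'\x00') for s in vec_set}

print(len(loop_set) == len(vec_set))
print(padded == loop_set)
```

If only membership tests or the number of unique vectors matter, the stripped entries are sufficient; the padding step is needed only when the bytes must be decoded back into coordinates.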