简体   繁体   English

使用NumPy索引数组切片Python列表 - 任何快速方式?

[英]Slicing a Python list with a NumPy array of indices — any fast way?

I have a regular list called a , and a NumPy array of indices b . 我有一个名为a的常规list和一个NumPy索引数组b
(No, it is not possible for me to convert a to a NumPy array.) (不,我不可能将a转换为NumPy数组。)

Is there any way for me to the same effect as " a[b] " efficiently? 对我有什么办法和“ a[b] ”有效的效果一样吗? To be clear, this implies that I don't want to extract every individual int in b due to its performance implications. 需要说明的是,这意味着我不想因为其性能影响而提取b每个int

(Yes, this is a bottleneck in my code. That's why I'm using NumPy arrays to begin with.) (是的,这是我的代码中的瓶颈。这就是我开始使用NumPy数组的原因。)

a = list(range(1000000))
b = np.random.randint(0, len(a), 10000)

%timeit np.array(a)[b]
10 loops, best of 3: 84.8 ms per loop

%timeit [a[x] for x in b]
100 loops, best of 3: 2.93 ms per loop

%timeit operator.itemgetter(*b)(a)
1000 loops, best of 3: 1.86 ms per loop

%timeit np.take(a, b)
10 loops, best of 3: 91.3 ms per loop

I had high hopes for numpy.take() but it is far from optimal. 我对numpy.take()寄予厚望,但它远非最佳。 I tried some Numba solutions as well, and they yielded similar times--around 92 ms. 我也尝试了一些Numba解决方案,他们产生了类似的时间 - 大约92毫秒。

So a simple list comprehension is not far from the best here, but operator.itemgetter() wins, at least for input sizes at these orders of magnitude. 因此,简单的列表理解与此处的最佳匹配并不相同,但是operator.itemgetter()获胜,至少对于这些数量级的输入大小而言。

Write a cython function: 写一个cython函数:

import cython
from cpython cimport PyList_New, PyList_SET_ITEM, Py_INCREF

@cython.wraparound(False)
@cython.boundscheck(False)
def take(list alist, Py_ssize_t[:] arr):
    cdef:
        Py_ssize_t i, idx, n = arr.shape[0]
        list res = PyList_New(n)
        object obj

    for i in range(n):
        idx = arr[i]
        obj = alist[idx]
        PyList_SET_ITEM(res, i, alist[idx])
        Py_INCREF(obj)

    return res

The result of %timeit: %timeit的结果:

import numpy as np

al= list(range(10000))
aa = np.array(al)

ba = np.random.randint(0, len(a), 10000)
bl = ba.tolist()

%timeit [al[i] for i in bl]
%timeit np.take(aa, ba)
%timeit take(al, ba)

1000 loops, best of 3: 1.68 ms per loop
10000 loops, best of 3: 51.4 µs per loop
1000 loops, best of 3: 254 µs per loop

numpy.take() is the fastest if both of the arguments are ndarray object. 如果两个参数都是ndarray对象,则numpy.take()是最快的。 The cython version is 5x faster than list comprehension. cython版本比列表理解快5倍。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM