我正在尝试通过在Cython中重写以下代码来优化以下代码:它仅使用低维但相对较长的numpy数组,在其列中查找0值,并将它们标记为-1。 代码是:

import numpy as np

def get_data():
    data = np.array([[1,5,1]] * 5000 + [[1,0,5]] * 5000 + [[0,0,0]] * 5000)
    return data

def get_cols(K):
    cols = np.array([2] * K)
    return cols

def test_nonzero(data):
    K = len(data)
    result = np.array([1] * K)
    # Index into columns of data
    cols = get_cols(K)
    # Mark zero points with -1
    idx = np.nonzero(data[np.arange(K), cols] == 0)[0]
    result[idx] = -1

import time
t_start = time.time()
data = get_data()
for n in range(5000):
    test_nonzero(data)
t_end = time.time()
print (t_end - t_start)

data就是数据。 cols是用于查找非零值的数据列的数组(为简单起见,我将它们都设置为同一列)。 目标是计算一个numpy数组result ,对于关注列非零的每一行,其值为1;对于对应关注列为零的行,其值为-1。

在不太大的15,000行乘3列的数组上运行此功能5000次需要大约20秒。 有什么办法可以加快速度吗? 看来,大部分工作都在寻找非零元素并使用索引进行检索(对nonzero的调用以及对其索引的后续使用。)是否可以对此进行优化或做到最好? Cython实施如何在此方面加快速度?

===============>>#1 票数:2

cols = np.array([2] * K)

这真的很慢。 那将创建一个非常大的python列表,然后将其转换为numpy数组。 相反,请执行以下操作:

cols = np.ones(K, int)*2

那会更快

result = np.array([1] * K)

在这里您应该做:

result = np.ones(K, int)

那将直接产生numpy数组。

idx = np.nonzero(data[np.arange(K), cols] == 0)[0]
result[idx] = -1

cols是一个数组,但是您只能传递2。此外,使用非零值会增加一个额外的步骤。

idx = data[np.arange(K), 2] == 0
result[idx] = -1

应该有同样的效果。

  ask by translate from so

未解决问题?本站智能推荐: