簡體   English   中英

Numpy sum 越多,輸入越快

[英]Numpy sum gets faster with more entries

我已經測量了 numpy 的 sum 函數添加一定數量的值所需的時間。 下面的曲線表示 numpy 的 sum 函數(綠色)和標准 python for 循環(藍色)將每個元素的所有元素相加所需的時間。 因此,例如,當在 10 1 個元素的數組上使用 numpy 的 sum 函數時,每個元素平均需要 2 -19秒來處理。

這是代碼:

random_vector = np.random.random(10**7)

def for_sum(vector, n):
    sum = 0   

    start_time = time.perf_counter()
    for i in range(n):
        sum += vector[i]

    return time.perf_counter() - start_time

def numpy_sum(vector, n):
    new_vector = vector[:n]

    start_time = time.perf_counter()
    np.sum(new_vector)

    return time.perf_counter() - start_time

# determines the number of elements we should sum
spaced_values = np.logspace(1, 7, num=30, dtype=int)

# Measure time for for loops, per entry
for_sum_times_per_entry = np.zeros(0)
for i in spaced_values:
    for_sum_times_per_entry = np.append(for_sum_times_per_entry, for_sum(random_vector, i)/i)

# Measure time for numpy sum function, per entry
numpy_sum_times_per_entry = np.zeros(0)
for i in spaced_values:
   numpy_sum_times_per_entry = np.append(numpy_sum_times_per_entry, numpy_sum(random_vector, i)/i)

# Plot the amount of time required to sum each entries
plt.loglog(spaced_values, for_sum_times_per_entry, basex=10, basey=2)
plt.loglog(spaced_values, numpy_sum_times_per_entry, basex=10, basey=2)
plt.xlabel("Number of values summed")
plt.ylabel("Single entry computing time (s)")
plt.show()    

numpy sum 和 for 循環

numpy sum 曲線顯示,隨着要求和的元素總數的增加,1 個元素所需的處理時間越來越少。

這是因為什么? 我的觀點是 numpy 的 sum 函數有一定的開銷,會增加處理時間。 這種開銷需要固定的時間,因此隨着我們添加元素,它變得越來越不重要。

Numpy 在numpy.sum使用 BLAS,因此對於numpy.sum您可以應用低級邏輯:如果整個陣列適合 CPU 緩存,您將看到平面配置文件。 我機器上的 L1 緩存是 64KB,只要你沒有超過 64KB / 64 位/浮點 ~ 8000 個浮點條目, numpy.sum的計算時間numpy.sum保持不變:

在此處輸入圖片說明

這個圖是用

import math
import numpy
import perfplot


def for_sum(a):
    sum = 0.0
    for i in range(len(a)):
        sum += a[i]
    return sum


perfplot.show(
    setup=numpy.random.rand,
    kernels=[for_sum, numpy.sum, sum, math.fsum],
    labels=["for-sum", "numpy.sum", "sum", "math.fsum"],
    n_range=[2 ** k for k in range(18)],
    xlabel="len(a)",
    logx=True,
    logy=True,
)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM