簡體   English   中英

為什么 np.linalg.norm(...,axis=1) 比寫出向量范數的公式慢?

[英]Why is np.linalg.norm(..., axis=1) slower than writing out the formula for vector norms?

要將矩陣X的行歸一化為單位長度,我通常使用:

X /= np.linalg.norm(X, axis=1, keepdims=True)

嘗試為算法優化此操作時,我很驚訝地發現在我的機器上寫出歸一化的速度大約快 40%:

X /= np.sqrt(X[:,0]**2+X[:,1]**2+X[:,2]**2)[:,np.newaxis]
X /= np.sqrt(sum(X[:,i]**2 for i in range(X.shape[1])))[:,np.newaxis]

怎么會? np.linalg.norm()的性能損失在np.linalg.norm()

import numpy as np
X = np.random.randn(10000,3)

%timeit X/np.linalg.norm(X,axis=1, keepdims=True)
# 276 µs ± 4.55 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit X/np.sqrt(X[:,0]**2+X[:,1]**2+X[:,2]**2)[:,np.newaxis]
# 169 µs ± 1.38 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit X/np.sqrt(sum(X[:,i]**2 for i in range(X.shape[1])))[:,np.newaxis]
# 185 µs ± 4.17 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

我在支持(2) python3.9 + numpy v1.19.3的 MacbookPro 2015 上觀察到(1) python3.6 + numpy v1.17.2(2) python3.9 + numpy v1.19.3

我不認為這是這篇文章的副本,它解決了矩陣范數,而這個是關於向量的 L2 范數。

row-wise L2-norm 的源代碼歸結為以下幾行代碼:

def norm(x, keepdims=False):
    x = np.asarray(x)
    s = x**2
    return np.sqrt(s.sum(axis=(1,), keepdims=keepdims))

簡化代碼假設x實值,並利用np.add.reduce(s, ...)等價於s.sum(...)的事實。

因此,OP 問題與詢問為什么np.sum(x,axis=1)sum(x[:,i] for i in range(x.shape[1]))

%timeit X.sum(axis=1, keepdims=False)
# 131 µs ± 1.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit sum(X[:,i] for i in range(X.shape[1]))
# 36.7 µs ± 91.2 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

這個問題已經在這里回答了。 簡而言之,減少( .sum(axis=1) )帶來的開銷成本通常在浮點精度和速度(例如緩存機制,並行性)方面得到回報,但在減少的特殊情況下不會僅超過三列。 在這種情況下,與實際計算相比,開銷相對較大。

如果X有更多列,情況就會改變。 numpy-boosted 標准化現在比使用 python for 循環的減少快得多:

X = np.random.randn(10000,100)
%timeit X/np.linalg.norm(X,axis=1, keepdims=True)
# 3.36 ms ± 132 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit X/np.sqrt(sum(X[:,i]**2 for i in range(X.shape[1])))[:,np.newaxis]
# 5.92 ms ± 168 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

另一個相關的 SO 線程在這里找到: numpy ufuncs vs. for loop

問題仍然是為什么 numpy 沒有明確處理常見的特殊簡化情況(例如對具有低軸維數的矩陣的列或行求和)。 可能是因為這種優化的效果往往強烈依賴於目標機器,並大大增加了代碼的復雜性。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM