簡體   English   中英

如何使用numpy.where()加速我的numpy循環

[英]How to speed up my numpy loop using numpy.where()

我最近編寫了一個關於有序logit模型的函數。
但是在運行大數據時需要花費很多時間。
所以我想重寫代碼並將numpy.where函數替換為if語句。
我的新代碼有一些問題,我不知道怎么做。
如果你知道,請幫助我。 非常感謝你!

這是我原來的功能。

import numpy as np
from scipy.stats import logistic

def func(y, X, thresholds):
    ll = 0.0
    for row in zip(y, X):
        if row[0] == 0:
           ll += logistic.logcdf(thresholds[0] - row[1])
        elif row[0] == len(thresholds):
           ll += logistic.logcdf(row[1] - thresholds[-1])
        else:
           for i in xrange(1, len(thresholds)):
               if row[0] == i:
                   diff_prob = logistic.cdf(thresholds[i] - row[1]) - logistic.cdf(thresholds[i - 1] - row[1])
                   if diff_prob <= 10 ** -5:
                       ll += np.log(10 ** -5)
                   else:
                       ll += np.log(diff_prob)
     return ll
y = np.array([0, 1, 2])
X = [2, 2, 2]
thresholds = np.array([2, 3])
print func(y, X, thresholds)

這是新的但不完美的代碼。

y = np.array([0, 1, 2])
X = [2, 2, 2]
thresholds = np.array([2, 3])
ll = np.where(y == 0, logistic.logcdf(thresholds[0] - X),
          np.where(y == len(thresholds), logistic.logcdf(X - thresholds[-1]),
                   np.log(logistic.cdf(thresholds[1] - X) - logistic.cdf(thresholds[0] - X))))
print ll.sum()

問題是我不知道如何重寫子循環( 對於i in xrange(1,len(thresholds)):)函數。

np.where一下如何使用np.where來實現它是一個X / Y問題

所以我將嘗試解釋如何優化此功能。

我的第一直覺是擺脫for循環,這無論如何都是痛點:

import numpy as np
from scipy.stats import logistic

def func1(y, X, thresholds):
    ll = 0.0
    for row in zip(y, X):
        if row[0] == 0:
            ll += logistic.logcdf(thresholds[0] - row[1])
        elif row[0] == len(thresholds):
            ll += logistic.logcdf(row[1] - thresholds[-1])
        else:
            diff_prob = logistic.cdf(thresholds[row[0]] - row[1]) - \
                         logistic.cdf(thresholds[row[0] - 1] - row[1])
            diff_prob = 10 ** -5 if diff_prob < 10 ** -5 else diff_prob
            ll += np.log(diff_prob)
    return ll

y = np.array([0, 1, 2])
X = [2, 2, 2]
thresholds = np.array([2, 3])
print(func1(y, X, thresholds))

我剛剛用row[0]替換了i ,而沒有改變循環的語義。 所以這是一個少循環。

現在我希望if-else的不同分支中的語句形式是相同的。 為此:

import numpy as np
from scipy.stats import logistic

def func2(y, X, thresholds):
    ll = 0.0

    for row in zip(y, X):
        if row[0] == 0:
            ll += logistic.logcdf(thresholds[0] - row[1])
        elif row[0] == len(thresholds):
            ll += logistic.logcdf(row[1] - thresholds[-1])
        else:
            ll += np.log(
                np.maximum(
                    10 ** -5, 
                    logistic.cdf(thresholds[row[0]] - row[1]) -
                     logistic.cdf(thresholds[row[0] - 1] - row[1])
                )
            )
    return ll

y = np.array([0, 1, 2])
X = [2, 2, 2]
thresholds = np.array([2, 3])
print(func2(y, X, thresholds))

現在每個分支中的表達式的形式為ll += expr

在這種情況下,優化可以采用幾種不同的路徑。 您可以嘗試通過將其作為一種理解來優化循環,但我懷疑它不會給你太多的速度提升。

另一條路徑是將if條件拉出循環。 這就是你對np.where的意圖:

import numpy as np
from scipy.stats import logistic

def func3(y, X, thresholds):
    y_0 = y == 0
    y_end = y == len(thresholds)
    y_rest = ~(y_0 | y_end)

    ll_1 = logistic.logcdf(thresholds[0] - X[ y_0 ])
    ll_2 = logistic.logcdf(X[ y_end ] - thresholds[-1])
    ll_3 = np.log(
        np.maximum(
            10 ** -5, 
            logistic.cdf(thresholds[y[ y_rest ]] - X[ y_rest ]) -
              logistic.cdf(thresholds[ y[y_rest] - 1 ] - X[ y_rest])
        )
    )
    return np.sum(ll_1) + np.sum(ll_2) + np.sum(ll_3)

y = np.array([0, 1, 2])
X = np.array([2, 2, 2])
thresholds = np.array([2, 3])
print(func3(y, X, thresholds))

請注意,我將X轉換為np.array ,以便能夠使用花式索引。

在這一點上,我打賭它對我的目的足夠快。 但是,根據您的要求,您可以提前或超出此點。


在我的計算機上,我得到以下結果:

y = np.random.random_integers(0, 10, size=(10000,))
X = np.random.random_integers(0, 10, size=(10000,))
thresholds = np.cumsum(np.random.rand(10))

%timeit func(y, X, thresholds) # Original
1 loops, best of 3: 1.51 s per loop

%timeit func1(y, X, thresholds) # Removed for-loop
1 loops, best of 3: 1.46 s per loop

%timeit func2(y, X, thresholds) # Standardized if statements
1 loops, best of 3: 1.5 s per loop

%timeit func3(y, X, thresholds) # Vectorized ~ 500x improvement
100 loops, best of 3: 2.74 ms per loop

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM