簡體   English   中英

脾氣暴躁的“:”運營商廣播問題

[英]Numpy “:” operator broadcasting issues

在下面的代碼中,我編寫了2種方法(理論上,在我看來)應該做同樣的事情。 不幸的是,他們沒有這樣做,我無法找出他們為什么不根據numpy文檔做同樣的事情。

import numpy as np


dW = np.zeros((20, 10))
y = [1 for _ in range(100)]
X =  np.ones((100, 20))

# ===================
# Method 1  (works!)
# ===================
for i in range(len(y)):
  dW[:, y[i]] -=  X[i]


# ===================
# Method 2 (does not work)
# ===================
dW[:, y] -=  X.T

如前所述,原則上,由於緩沖在NumPy中的工作方式,因此您無法在同一操作中對同一元素進行多次操作。 為此,提供了at函數,該函數可用於幾乎所有標准NumPy函數( addsubtract等)。 對於您的情況,您可以執行以下操作:

import numpy as np

dW = np.zeros((20, 10))
y = [1 for _ in range(100)]
X =  np.ones((100, 20))
# at modifies in place dW, does not return a new array
np.subtract.at(dW, (slice(None), y), X.T)

這是問題的逐欄形式。

那里的答案可以調整為按列工作,如下所示:

方法1: np.<ufunc>.at

>>> np.subtract.at(dW, (slice(None), y), X.T)

方法2: np.bincount

>>> m, n = dW.shape
>>> dW -= np.bincount(np.add.outer(np.arange(m) * n, y).ravel(), (X.T).ravel(), dW.size).reshape(m, n)

請注意, bincount基於bincount的解決方案涉及更多步驟,但其速度卻快了約6倍。

>>> from timeit import repeat
>>> kwds = dict(globals=globals(), number=5000)
>>>
>>> repeat('np.subtract.at(dW, (slice(None), y), X.T); np.add.at(dW, (slice(None), y), X.T)', **kwds)
[1.590626839082688, 1.5769231889862567, 1.5802007300080732]
>>> repeat('_= dW; _ -= np.bincount(np.add.outer(np.arange(m) * n, y).ravel(), (X.T).ravel(), dW.size).reshape(m, n); _ += np.bincount(np.add.outer(np.arange(m) * n, y).ravel(), (X.T).ravel(), dW.size).reshape(m, n)', **kwds)
[0.2582490430213511, 0.25572817400097847, 0.25478115503210574]

選項1:

for i in range(len(y)):
  dW[:, y[i]] -=  X[i]

之所以有效,是因為您正在遍歷和更新上次更新的值。

選項2:

dW[:, [1,1,1,1,....1,1,1]] -=  [[1,1,1,1...1],
                                [1,1,1,1...1],
                                .
                                .
                                [1,1,1,1...1]]

它不起作用,因為同時並行而不是串行地對第一索引進行更新。 最初都是0,所以減去結果為-1。

我找到了解決該問題的第三種方法。 法線矩陣乘法:

ind = np.zeros((X.shape[0],dW.shape[1]))
ind[range(X.shape[0]),y] = -1
dW = X.T.dot(ind)

我在一些神經網絡數據上使用上述建議的方法進行了一些實驗。 在我的示例中X.shape = (500,3073)W.shape = (3073,10)ind.shape = (500,10)

減法版本大約需要0.2秒(最慢)。 矩陣乘法方法為0.01 s(最快)。 正常循環0.015,然后bincount方法0.04 s。 請注意,在問題y是1的向量。 這不是我的情況。 只有一個的情況可以用一個簡單的和解。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM