[英]How to speed up python object functions using numba/CUDA?
我是 CUDA 的新手(大約一個小時前安裝了 numba)。 我想加速這個類中的函數。
def predict(self, X):
num_test = X.shape[0]
Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
for i in range(num_test):
distances = np.sum(np.abs(self.Xtr-X[i, :]), axis=1)
min_index = np.argmin(distances)
Ypred[i] = self.ytr[min_index]
print(i)
return Ypred
X 是 float32 類型的二維數組,Ypred 是 int32 類型的數組。 我試圖通過在函數上方插入以下行來加速它。
@vectorize(['int32(float32)'], target='cuda')
這給了我一大堆錯誤,但其中重要的部分似乎是:
TypeError: Failed at nopython (analyzing bytecode)
Signature mismatch: 1 argument types given, but function takes 2 arguments
雖然我確切地知道錯誤說的是什么,但我不知道如何解決它。 那么......我如何使它工作? 提前致謝。
更新:
我應該在詢問之前進行適當的谷歌搜索(我確實進行了搜索,但我使用了“對象”而不是“類”這個詞,這沒有給我任何有用的結果)。 文檔對我幫助很大,但現在我遇到了這些錯誤,我不知道該怎么做。
numba.errors.LoweringError: Failed at nopython (nopython mode backend)
Can only insert float* at [4] in {i8*, i8*, i64, i64, float*, [2 x i64],
[2 x i64]}: got double*
File "main.py", line 40
[1] During: lowering "(self).Xtr = X" at D:/myStuff/DL/Week 3/1/main.py (40)
[2] During: resolving callee type:
BoundFunction((<class 'numba.types.misc.ClassInstanceType'>, 'train') f for instance.jitclass.NearestNeighbours#24f01184f58<Xtr:array(float32, 2d, A),ytr:array(int32, 1d, A)>)
[3] During: typing of call at <string> (3)
--%<-----------------------------------------------------------------
File "<string>", line 3
這是當前狀態下的整個類:
spec = [("Xtr", float32[:, :]), ("ytr", int32[:])]
@jitclass(spec)
class NearestNeighbours(object):
def __init__(self):
pass
def train(self, X, y):
self.Xtr = X #line 40
self.ytr = y
def predict(self, X):
num_test = X.shape[0]
Ypred = np.zeros(num_test, dtype=self.ytr.dtype)
for i in range(num_test):
distances = np.sum(np.abs(self.Xtr-X[i, :]), axis=1)
min_index = np.argmin(distances)
Ypred[i] = self.ytr[min_index]
print(i)
return Ypred
更新 2:放棄 jitting 類並嘗試將 predict 鏈接到其外部克隆。 使用空 jit 似乎有效,但鏈接到 cuda(為了速度)會導致各種奇怪的錯誤。 如果我以某種方式設法解決,我今天會休息一下並回答我自己的問題。 直到幾個小時前,我還認為 GPU 加速就像添加一個額外的庫或切換到不同的編譯器或其他東西一樣簡單......但是伙計......我不知道我會經歷如此坎坷的旅程.
據我所知,您的函數僅依賴於X
因此沒有理由將其作為類中的函數。 要么將其聲明為靜態@staticmethod
(鏈接),要么將其從類 - 范圍中取出。
Presto:僅剩 1 個函數參數。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.