簡體   English   中英

Cython:矩陣乘法

[英]Cython: matrix multiplication

我已經對以下使用 numpy 的矩陣乘法的文件進行了 cythonized:

def cell(float[:, ::1] a, float[:, ::1] b):
  c = a @ b
  return c

但是,當我用以下方式調用它時:

from matmul import cell
import numpy as np


a = np.zeros((1, 64), dtype=np.float32)
b = np.zeros((64, 64), dtype=np.float32)
c = cell(a, b)

我收到以下錯誤:

類型錯誤:@ 不支持的操作數類型:_memoryviewslice 和 _memoryviewslice

如何使用 Cython 執行矩陣乘法?

上下文:function“單元”是我編寫的代碼的一部分,它通過 LSTM 網絡執行預測(我手動編寫了它,沒有使用 PyTorch 或 Tensorflow)。 我需要加速代碼才能實時使用網絡。

如果這就是你所做的一切,那么為cell的參數添加類型實際上是沒有意義的——你所做的只是無緣無故地添加昂貴的類型檢查。 Cython 無法有效利用這些類型。 只需不ab即可。

如果您確實需要使用 Numpy 整個數組操作來修復 memoryviews 操作,最簡單的解決方案是調用np.asarray

def cell(float[:, ::1] a, float[:, ::1] b):
  c = np.asarray(a) @ np.asarray(b)
  return c

您在這里沒有從 Cython 獲得任何好處 - 它只是調用 Numpy 矩陣乘法代碼。 因此,僅在需要將其與您確實受益於 Cython 的某些操作混合時才執行此操作。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM