Cython: matrix multiplication

I have cythonized the following file that uses numpy's matrix multiplication:

def cell(float[:, ::1] a, float[:, ::1] b):
  c = a @ b
  return c

However, when I call it with:

from matmul import cell
import numpy as np


a = np.zeros((1, 64), dtype=np.float32)
b = np.zeros((64, 64), dtype=np.float32)
c = cell(a, b)

I get the following error:

TypeError: unsupported operand type(s) for @: _memoryviewslice and _memoryviewslice

How can I perform matrix multiplication with Cython?

Context: the function "cell" is part of code I wrote that runs predictions with an LSTM network (written by hand with NumPy only, without PyTorch or TensorFlow). I need to speed the code up so the network can be used in real time.

If that's all you're doing, there's no point in adding types to the arguments of cell: all you're doing is adding expensive type checks for no reason, because Cython can't make useful use of these types here. Just leave a and b untyped.
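A minimal sketch of the untyped version, assuming the function does nothing beyond the multiplication. The same source is valid as plain Python and as Cython; the sample shapes are the ones from the question.

```python
import numpy as np

def cell(a, b):
    # No type declarations: a and b stay NumPy arrays, so @ dispatches
    # straight to NumPy's matrix multiplication.
    return a @ b

a = np.zeros((1, 64), dtype=np.float32)
b = np.zeros((64, 64), dtype=np.float32)
c = cell(a, b)
print(c.shape, c.dtype)
```

Compiling this with Cython should not be expected to change the cost of the multiplication itself, since the work is still done inside NumPy.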

If you do actually need to mix memoryview operations with NumPy whole-array operations, the easiest solution is to call np.asarray:

import numpy as np

def cell(float[:, ::1] a, float[:, ::1] b):
  # np.asarray wraps each memoryview as an ndarray without copying the data
  c = np.asarray(a) @ np.asarray(b)
  return c

You aren't getting any benefit from Cython here - it's just calling into the Numpy matrix multiply code. So only do this where you need to mix it with some operations where you do benefit from Cython.
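To illustrate why the np.asarray conversion is cheap, here is a sketch in plain Python. It uses the built-in memoryview as a stand-in for Cython's typed memoryview; that substitution is an assumption made so the example runs without compiling. The two types are different, but both expose the buffer protocol and neither implements @.

```python
import numpy as np

a = np.zeros((1, 64), dtype=np.float32)
b = np.ones((64, 64), dtype=np.float32)

# Stand-ins for the typed memoryview arguments (assumption: the built-in
# memoryview is used only because it likewise has no @ operator).
mv_a, mv_b = memoryview(a), memoryview(b)

try:
    mv_a @ mv_b
    matmul_failed = False
except TypeError:
    # Same class of failure as in the question: no matmul on memoryviews.
    matmul_failed = True

# np.asarray wraps the existing buffer, so no data is copied.
arr_a, arr_b = np.asarray(mv_a), np.asarray(mv_b)
c = arr_a @ arr_b
shares = np.shares_memory(arr_a, a)
```

Here shares reports whether the wrapped array and the original share one buffer, which is what makes the conversion essentially free compared with the multiplication.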
