繁体   English   中英

在 Python 中计算伪逆 (pinv) 的最快方法

[英]Fastest way for computing pseudoinverse (pinv) in Python

我有一个循环,我在其中计算相当大的非稀疏矩阵(例如20000x800 )的几个伪逆。

由于我的代码大部分时间花在pinv上,我试图找到一种方法来加速计算。 我已经在使用多处理 ( joblib/loky ) 来运行多个进程,但这当然也会增加开销。 使用jit并没有多大帮助。

有没有更快的方法/更好的实现来使用任何 function 计算伪逆? 精度不是关键。

我目前的基准

import time
import numba
import numpy as np
from numpy.linalg import pinv as np_pinv
from scipy.linalg import pinv as scipy_pinv
from scipy.linalg import pinv2 as scipy_pinv2

@numba.njit
def np_jit_pinv(A):
  return np_pinv(A)

matrix = np.random.rand(20000, 800)
for pinv in [np_pinv, scipy_pinv, scipy_pinv2, np_jit_pinv]:
    start = time.time()
    pinv(matrix)
    print(f'{pinv.__module__ +"."+pinv.__name__} took {time.time()-start:.3f}')
numpy.linalg.pinv took 2.774
scipy.linalg.basic.pinv took 1.906
scipy.linalg.basic.pinv2 took 1.682
__main__.np_jit_pinv took 2.446

编辑:JAX 似乎快了 30%。 感人的。 感谢您让我知道@yuri-brigance。 对于 Windows,它在 WSL 下运行良好。

numpy.linalg.pinv took 2.774
scipy.linalg.basic.pinv took 1.906
scipy.linalg.basic.pinv2 took 1.682
__main__.np_jit_pinv took 2.446
jax._src.numpy.linalg.pinv took 0.995

尝试使用 JAX:

import jax.numpy as jnp

jnp.linalg.pinv(A)

似乎比常规numpy.linalg.pinv稍快。 在我的机器上你的基准看起来像这样:

jax._src.numpy.linalg.pinv took 3.127
numpy.linalg.pinv took 4.284

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM