繁体   English   中英

在涉及线性代数的函数上应用 Numpy 广播

[英]Applying Numpy broadcasting on function involving linear algebra

我想在涉及线性代数(没有分母部分的双变量高斯分布)的数学函数上使用numpy广播功能。 我的代码的最小的、可重现的示例是这样的:

我有以下功能

import numpy as np
def  gaussian(x):
    mu = np.array([[2],
                   [2]])
    sigma = np.array([[10, 0],
                      [0, 10]])
    xm = x - mu
    result = np.exp((-1/2) * xm.T @ np.linalg.inv(sigma) @ xm)
    return result

该函数假定x是一个 2x1 数组。 我的目标是使用该函数生成一个二维数组,其中各个元素是该函数的乘积。 我按如下方式应用此功能:

x, y = np.arange(5), np.arange(5)
xLen, yLen = len(x), len(y)
z = np.zeros((yLen, xLen))

for y_index in range(yLen):
        for x_index in range(xLen):
            element = np.array([[x[x_index]],
                                [y[y_index]]])
            result = gaussian(element)
            z[y_index][x_index] = result

这可行,但如您所见,我使用两个 for 循环进行索引。 我知道这是不好的做法,并且在使用更大的数组时速度非常慢。 我想用numpy广播功能解决这个问题。 我尝试了以下代码:

X, Y = np.meshgrid(x, y, indexing= 'xy')
element = np.array([[X],
                    [Y]])
Z = gaussian(element)

但是我收到了这个错误: ValueError: operands could not be broadcast together with shapes (2,1,5,5) (2,1)函数的xm = x - mu行的形状 (2,1,5,5) (2,1) 一起广播。 我在一定程度上理解这个错误。

此外,即使我解决了这个问题,我也会收到另一个错误: ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 5 is different from 2)result = np.exp((-1/2) * xm.T @ np.linalg.inv(sigma) @ xm)函数的行。 再次,我明白为什么。 xm将不再是 2x1 数组,并且将其与sigma (即 2x2)相乘将不起作用。

有人对如何修改我的功能以使广播实现工作有建议吗?

以下可能有效。 有两点需要注意:

  • 我正在使用np.einsum进行向量矩阵向量乘法。 可能有更快的方法,但这可以很好地处理要广播的其他维度。
  • 根据我的经验,对于较大的数组,使用 3 维以上的广播,事情可能不会比简单的嵌套循环快。 我还没有深入研究:也许计算是在错误的维度上进行的(列与行问题),这会减慢速度。 因此,也许通过调整或调整维度顺序,可以加快速度

设置代码

nx, ny = 5, 5
x, y = np.arange(nx), np.arange(ny)
X, Y = np.meshgrid(x, y, indexing= 'xy')
element = np.array([[X],
                    [Y]])
# Stack X and Y into a nx x ny x 2 array
XY = np.dstack([X, Y])

新功能

def  gaussian(x):
    # Note that I have removed the extra dimension: 
    # mu is a simple array of shape (2,)
    # This is no problem, since we're using einsum
    # for the matrix multiplication
    mu = np.array([2, 2])
    sigma = np.array([[10, 0],
                      [0, 10]])
    # Broadcast xm to x's shape: (nx, ny, 2)
    xm = x - mu[..., :]
    invsigma = np.linalg.inv(sigma)
    # Compute the (double) matrix multiplication
    # Leave the first two dimension (ab) alone
    # The other dimensions will sum up to a single scalar
    # and thus only the ab dimensions are there in the output
    alpha = np.einsum('abi,abj,ji->ab', xm, xm, invsigma)
    result = np.exp((-1/2) * alpha)
    # The shape of result is (nx, ny)
    return result

然后调用:

gaussian(XY)

显然,请仔细检查。 我做了一个简短的检查,这似乎是正确的,但转录错误可能例如交换了尺寸。

所以 (2,1) 输入返回 (1,1) 结果:

In [83]: gaussian(np.ones((2,1)))
Out[83]: array([[0.90483742]])

添加一些主要维度:

In [84]: gaussian(np.ones((3,4,2,1)))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [84], in <cell line: 1>()
----> 1 gaussian(np.ones((3,4,2,1)))

Input In [80], in gaussian(x)
      4 sigma = np.array([[10, 0],
      5                   [0, 10]])
      6 xm = x - mu
----> 7 result = np.exp((-1/2) * xm.T @ np.linalg.inv(sigma) @ xm)
      8 return result

ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 3)

x-mu有效,因为 (3,4,2,1) 与 (2,1) 一起广播

错误发生在(-1/2) * xm.T @ np.linalg.inv(sigma)

np.linalg.inv(sigma)是 (2,2)

xm(3,4,2,1) ,所以它的转置是 (1,2,4,3)。

相反,如果数组是 (3,4,1,2) @ (2,2) @ (3,4,2,1),则结果应该是 (3,4,1,1)。

所以让我们改进转置:

def  gaussian(x):
    mu = np.array([[2],
                   [2]])
    sigma = np.array([[10, 0],
                      [0, 10]])
    xm = x - mu
    xmt =xm.swapaxes(-2,-1)
    result = np.exp((-1/2) * xmt @ np.linalg.inv(sigma) @ xm)
    return result

现在它适用于原始 (2,1) 和任何其他 (n,m,2,1) 形状:

In [87]: gaussian(np.ones((3,4,2,1))).shape
Out[87]: (3, 4, 1, 1)

In [88]: gaussian(np.ones((2,1))).shape
Out[88]: (1, 1)

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM