簡體   English   中英

在涉及線性代數的函數上應用 Numpy 廣播

[英]Applying Numpy broadcasting on function involving linear algebra

我想在涉及線性代數(沒有分母部分的雙變量高斯分布)的數學函數上使用numpy廣播功能。 我的代碼的最小的、可重現的示例是這樣的:

我有以下功能

import numpy as np
def  gaussian(x):
    mu = np.array([[2],
                   [2]])
    sigma = np.array([[10, 0],
                      [0, 10]])
    xm = x - mu
    result = np.exp((-1/2) * xm.T @ np.linalg.inv(sigma) @ xm)
    return result

該函數假定x是一個 2x1 數組。 我的目標是使用該函數生成一個二維數組,其中各個元素是該函數的乘積。 我按如下方式應用此功能:

x, y = np.arange(5), np.arange(5)
xLen, yLen = len(x), len(y)
z = np.zeros((yLen, xLen))

for y_index in range(yLen):
        for x_index in range(xLen):
            element = np.array([[x[x_index]],
                                [y[y_index]]])
            result = gaussian(element)
            z[y_index][x_index] = result

這可行,但如您所見,我使用兩個 for 循環進行索引。 我知道這是不好的做法,並且在使用更大的數組時速度非常慢。 我想用numpy廣播功能解決這個問題。 我嘗試了以下代碼:

X, Y = np.meshgrid(x, y, indexing= 'xy')
element = np.array([[X],
                    [Y]])
Z = gaussian(element)

但是我收到了這個錯誤: ValueError: operands could not be broadcast together with shapes (2,1,5,5) (2,1)函數的xm = x - mu行的形狀 (2,1,5,5) (2,1) 一起廣播。 我在一定程度上理解這個錯誤。

此外,即使我解決了這個問題,我也會收到另一個錯誤: ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 5 is different from 2)result = np.exp((-1/2) * xm.T @ np.linalg.inv(sigma) @ xm)函數的行。 再次,我明白為什么。 xm將不再是 2x1 數組,並且將其與sigma (即 2x2)相乘將不起作用。

有人對如何修改我的功能以使廣播實現工作有建議嗎?

以下可能有效。 有兩點需要注意:

  • 我正在使用np.einsum進行向量矩陣向量乘法。 可能有更快的方法,但這可以很好地處理要廣播的其他維度。
  • 根據我的經驗,對於較大的數組,使用 3 維以上的廣播,事情可能不會比簡單的嵌套循環快。 我還沒有深入研究:也許計算是在錯誤的維度上進行的(列與行問題),這會減慢速度。 因此,也許通過調整或調整維度順序,可以加快速度

設置代碼

nx, ny = 5, 5
x, y = np.arange(nx), np.arange(ny)
X, Y = np.meshgrid(x, y, indexing= 'xy')
element = np.array([[X],
                    [Y]])
# Stack X and Y into a nx x ny x 2 array
XY = np.dstack([X, Y])

新功能

def  gaussian(x):
    # Note that I have removed the extra dimension: 
    # mu is a simple array of shape (2,)
    # This is no problem, since we're using einsum
    # for the matrix multiplication
    mu = np.array([2, 2])
    sigma = np.array([[10, 0],
                      [0, 10]])
    # Broadcast xm to x's shape: (nx, ny, 2)
    xm = x - mu[..., :]
    invsigma = np.linalg.inv(sigma)
    # Compute the (double) matrix multiplication
    # Leave the first two dimension (ab) alone
    # The other dimensions will sum up to a single scalar
    # and thus only the ab dimensions are there in the output
    alpha = np.einsum('abi,abj,ji->ab', xm, xm, invsigma)
    result = np.exp((-1/2) * alpha)
    # The shape of result is (nx, ny)
    return result

然后調用:

gaussian(XY)

顯然,請仔細檢查。 我做了一個簡短的檢查,這似乎是正確的,但轉錄錯誤可能例如交換了尺寸。

所以 (2,1) 輸入返回 (1,1) 結果:

In [83]: gaussian(np.ones((2,1)))
Out[83]: array([[0.90483742]])

添加一些主要維度:

In [84]: gaussian(np.ones((3,4,2,1)))
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [84], in <cell line: 1>()
----> 1 gaussian(np.ones((3,4,2,1)))

Input In [80], in gaussian(x)
      4 sigma = np.array([[10, 0],
      5                   [0, 10]])
      6 xm = x - mu
----> 7 result = np.exp((-1/2) * xm.T @ np.linalg.inv(sigma) @ xm)
      8 return result

ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 2 is different from 3)

x-mu有效,因為 (3,4,2,1) 與 (2,1) 一起廣播

錯誤發生在(-1/2) * xm.T @ np.linalg.inv(sigma)

np.linalg.inv(sigma)是 (2,2)

xm(3,4,2,1) ,所以它的轉置是 (1,2,4,3)。

相反,如果數組是 (3,4,1,2) @ (2,2) @ (3,4,2,1),則結果應該是 (3,4,1,1)。

所以讓我們改進轉置:

def  gaussian(x):
    mu = np.array([[2],
                   [2]])
    sigma = np.array([[10, 0],
                      [0, 10]])
    xm = x - mu
    xmt =xm.swapaxes(-2,-1)
    result = np.exp((-1/2) * xmt @ np.linalg.inv(sigma) @ xm)
    return result

現在它適用於原始 (2,1) 和任何其他 (n,m,2,1) 形狀:

In [87]: gaussian(np.ones((3,4,2,1))).shape
Out[87]: (3, 4, 1, 1)

In [88]: gaussian(np.ones((2,1))).shape
Out[88]: (1, 1)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM