簡體   English   中英

Numpy廣播執行歐式距離矢量化

[英]Numpy Broadcast to perform euclidean distance vectorized

我有2 x 4和3 x 4的矩陣。我想找到各行之間的歐幾里得距離,並在最后得到2 x 3的矩陣。 這是一個帶for循環的代碼,它針對所有b行向量計算a中每個行向量的歐式距離。 在不使用for循環的情況下該如何做?

 import numpy as np
a = np.array([[1,1,1,1],[2,2,2,2]])
b = np.array([[1,2,3,4],[1,1,1,1],[1,2,1,9]])
dists = np.zeros((2, 3))
for i in range(2):
      dists[i] = np.sqrt(np.sum(np.square(a[i] - b), axis=1))

以下是原始輸入變量:

A = np.array([[1,1,1,1],[2,2,2,2]])
B = np.array([[1,2,3,4],[1,1,1,1],[1,2,1,9]])
A
# array([[1, 1, 1, 1],
#        [2, 2, 2, 2]])
B
# array([[1, 2, 3, 4],
#        [1, 1, 1, 1],
#        [1, 2, 1, 9]])

A是2x4陣列。 B是3x4陣列。

我們要在一個完全矢量化的運算中計算歐幾里得距離矩陣運算,其中dist[i,j]包含A中第ith個實例與B中第j個實例之間的距離。因此,在此示例中dist為2x3。

距離

在此處輸入圖片說明

表面上可以用numpy編寫為

dist = np.sqrt(np.sum(np.square(A-B))) # DOES NOT WORK
# Traceback (most recent call last):
#   File "<stdin>", line 1, in <module>
# ValueError: operands could not be broadcast together with shapes (2,4) (3,4)

然而,如上所述,問題在於逐元素減法運算AB涉及不兼容的陣列尺寸,特別是在第一維度上為2和3。

A has dimensions 2 x 4
B has dimensions 3 x 4

為了進行逐元素減法,我們必須填充A或B以滿足numpy的廣播規則。 我將選擇以額外的尺寸填充A,以使其變為2 x 1 x 4,從而使陣列的尺寸可以排成一排以進行廣播。 有關numpy廣播的更多信息,請參見scipy手冊中教程以及本教程中的最后一個示例。

您可以使用np.newaxis值或np.reshape命令執行填充。 我在下面都顯示:

# First approach is to add the extra dimension to A with np.newaxis
A[:,np.newaxis,:] has dimensions 2 x 1 x 4
B has dimensions                     3 x 4

# Second approach is to reshape A with np.reshape
np.reshape(A, (2,1,4)) has dimensions 2 x 1 x 4
B has dimensions                          3 x 4

如您所見,使用這兩種方法都會使尺寸對齊。 我將在np.newaxis使用第一種方法。 所以現在,這將可以創建AB(一個2x3x4數組):

diff = A[:,np.newaxis,:] - B
# Alternative approach:
# diff = np.reshape(A, (2,1,4)) - B
diff.shape
# (2, 3, 4)

現在,我們可以將該差異表達式放入dist方程語句中,以獲得最終結果:

dist = np.sqrt(np.sum(np.square(A[:,np.newaxis,:] - B), axis=2))
dist
# array([[ 3.74165739,  0.        ,  8.06225775],
#        [ 2.44948974,  2.        ,  7.14142843]])

請注意, sumaxis=2 ,這意味着將總和乘以2x3x4數組的第三條軸(其中軸id以0開頭)。

如果您的陣列很小,那么上面的命令就可以正常工作。 但是,如果陣列很大,則可能會遇到內存問題。 請注意,在上面的示例中,numpy內部創建了一個2x3x4數組來執行廣播。 如果我們推廣一個有尺寸axz和B具有尺寸bxz ,然后numpy的將在內部創建一個axbxz用於廣播陣列。

我們可以通過進行一些數學操作來避免創建此中間數組。 因為您將歐幾里德距離計算為平方差之和,所以我們可以利用數學事實,即平方差之和可以被重寫。

在此處輸入圖片說明

注意,中間項涉及元素與乘法的和。 乘法的總和被稱為點積。 由於A和B均為矩陣,因此此操作實際上是矩陣乘法。 因此,我們可以將以上內容重寫為:

在此處輸入圖片說明

然后,我們可以編寫以下numpy代碼:

threeSums = np.sum(np.square(A)[:,np.newaxis,:], axis=2) - 2 * A.dot(B.T) + np.sum(np.square(B), axis=1)
dist = np.sqrt(threeSums)
dist
# array([[ 3.74165739,  0.        ,  8.06225775],
#        [ 2.44948974,  2.        ,  7.14142843]])

請注意,上面的答案與先前的實現完全相同。 同樣,這里的優點是我們不需要創建用於廣播的中間2x3x4陣列。

為了完整threeSums ,讓我們threeSums檢查threeSums中每個被求和的維數threeSums允許廣播。

np.sum(np.square(A)[:,np.newaxis,:], axis=2) has dimensions 2 x 1
2 * A.dot(B.T) has dimensions                               2 x 3
np.sum(np.square(B), axis=1) has dimensions                 1 x 3

因此,正如預期的那樣,最終的dist數組的尺寸為2x3。

本教程還討論了使用點積代替元素乘積之和。

最近在使用深度學習時遇到了相同的問題(stanford cs231n,Assignment1),但是當我使用

 np.sqrt((np.square(a[:,np.newaxis]-b).sum(axis=2)))

發生錯誤

MemoryError

這意味着我內存不足(實際上,它在中間產生了一個500 * 5000 * 1024的數組。它是如此之大!)

為了防止出現該錯誤,我們可以使用公式來簡化:

碼:

import numpy as np
aSumSquare = np.sum(np.square(a),axis=1);
bSumSquare = np.sum(np.square(b),axis=1);
mul = np.dot(a,b.T);
dists = np.sqrt(aSumSquare[:,np.newaxis]+bSumSquare-2*mul)

只需在正確的位置使用np.newaxis

 np.sqrt((np.square(a[:,np.newaxis]-b).sum(axis=2)))

此功能已包含在scipy的空間模塊中 ,我建議使用它,因為它將在引擎蓋下進行矢量化和高度優化。 但是,從其他答案可以明顯看出,您可以通過多種方式自己執行此操作。

import numpy as np
a = np.array([[1,1,1,1],[2,2,2,2]])
b = np.array([[1,2,3,4],[1,1,1,1],[1,2,1,9]])
np.sqrt((np.square(a[:,np.newaxis]-b).sum(axis=2)))
# array([[ 3.74165739,  0.        ,  8.06225775],
#       [ 2.44948974,  2.        ,  7.14142843]])
from scipy.spatial.distance import cdist
cdist(a,b)
# array([[ 3.74165739,  0.        ,  8.06225775],
#       [ 2.44948974,  2.        ,  7.14142843]])

使用numpy.linalg.norm也可以很好地與廣播配合使用。 axis指定整數值將使用矢量范數,默認為歐幾里得范數。

import numpy as np

a = np.array([[1,1,1,1],[2,2,2,2]])
b = np.array([[1,2,3,4],[1,1,1,1],[1,2,1,9]])
np.linalg.norm(a[:, np.newaxis] - b, axis = 2)

# array([[ 3.74165739,  0.        ,  8.06225775],
#       [ 2.44948974,  2.        ,  7.14142843]])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM