簡體   English   中英

使用NumPy argsort並接收2D數組

[英]Using NumPy argsort and take in 2D arrays

目的是計算兩組點之間的距離矩陣( set1set2 ),使用argsort()獲取已排序的索引,使用take()來提取已排序的數組。 我知道我可以直接進行sort() ,但是我需要為下一步做索引。

我正在使用這里討論的花哨的索引概念 我無法直接使用獲取的索引矩陣使用take() ,但是向每行添加相應的數量使其工作,因為take()平源數組,使第二行元素具有索引+ = len(set2 ),第三行索引+ = 2 * len(set2)等等(見下文):

dist  = np.subtract.outer( set1[:,0], set2[:,0] )**2
dist += np.subtract.outer( set1[:,1], set2[:,1] )**2
dist += np.subtract.outer( set1[:,2], set2[:,2] )**2
a = np.argsort( dist, axis=1 )
a += np.array([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
               [10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
               [20, 20, 20, 20, 20, 20, 20, 20, 20, 20],
               [30, 30, 30, 30, 30, 30, 30, 30, 30, 30]])
s1 = np.sort(dist,axis=1)
s2 = np.take(dist,a)
np.nonzero((s1-s2)) == False
#True # meaning that it works...

主要問題是:有沒有直接的方法來使用take()而不總結這些索引?

要播放的數據:

set1 = np.array([[ 250., 0.,    0.],
                 [ 250., 0.,  510.],
                 [-250., 0.,    0.],
                 [-250., 0.,    0.]])

set2 = np.array([[  61.0, 243.1, 8.3],
                 [ -43.6, 246.8, 8.4],
                 [ 102.5, 228.8, 8.4],
                 [  69.5, 240.9, 8.4],
                 [ 133.4, 212.2, 8.4],
                 [ -52.3, 245.1, 8.4],
                 [-125.8, 216.8, 8.5],
                 [-154.9, 197.1, 8.6],
                 [  61.0, 243.1, 8.7],
                 [ -26.2, 249.3, 8.7]])

其他相關問題:

- 兩個不同的Numpy陣列中的點之間的歐幾里德距離,而不在其中

我不認為有一種方法可以使用np.take而不用平坦的指數。 由於維度可能會發生變化,因此最好使用np.ravel_multi_index ,執行以下操作:

a = np.argsort(dist, axis=1)
a = np.ravel_multi_index((np.arange(dist.shape[0])[:, None], a), dims=dist.shape)

或者,您可以使用花式索引而不使用take

s2 = dist[np.arange(4)[:, None], a]

截至2018年5月,有np.take_along_axis

s2 = np.take_along_axis(dist, a, axis=1)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM