簡體   English   中英

在 numpy 矩陣中搜索向量

[英]Searching for vectors within a numpy matrix

給定以下帶有索引的矩陣ixs ,我正在 ixs 中尋找一個與ix等效的向量(也是ixs的行/向量),除了維度 1 (可以假定任何值)和維度 3 需要設置為1 .

ixs = np.asarray([
 [0, 0, 3, 0, 1], # 0. current value of `ix`
 [0, 0, 3, 1, 1], # 1.
 [0, 1, 3, 0, 0], # 2.
 [0, 1, 3, 0, 1], # 3.
 [0, 1, 3, 1, 1], # 4.
 [0, 2, 3, 0, 1], # 5.
 [0, 2, 3, 1, 1]  # 6.
])
ix = np.asarray([0, 0, 3, 0, 1])

因此,對於[0, 0, 3, 0, 1]ix ,我將查看低於該行(第 1..6 行)的所有行,並查找模式[0, *, 3, 1, 1]即 1. [0, 0, 3, 1, 1] , 4. [0, 1, 3, 1, 1] , 6. [0, 2, 3, 1, 1]

獲得這些向量的最佳(簡潔)方法是什么?

此解決方案僅使用 numpy(非常快)和幾個邏輯操作。 最后,它給出了正確的列。

ixs = np.matrix([
 [0, 0, 3, 0, 1], # 0. current value of `ix`
 [0, 0, 3, 1, 1], # 1.
 [0, 1, 3, 0, 0], # 2.
 [0, 1, 3, 0, 1], # 3.
 [0, 1, 3, 1, 1], # 4.
 [0, 2, 3, 0, 1], # 5.
 [0, 2, 3, 1, 1]  # 6.
])

newixs = ixs

#since the second column does not matter, we just assign it 0 in the new matrix.

newixs[:,1] = 0 

#here it compares the each row against the 0 indexed row
#then, it multiplies the True and False values with 1
#and the result is 0,1 values in an array. 
#then it takes the averages at the row level
#if the average is 1, then it means that all values match

mask = ((newixs == newixs[0])*1).mean(axis=1) == 1

#it then converts the matrix to array for masking
mask = np.squeeze(np.asarray(mask))

#using the mask value, we select the matched columns
ixs[mask,:]
matrix([[0, 0, 3, 0, 1],
        [0, 1, 3, 0, 1],
        [0, 2, 3, 0, 1]])

這是使用 cdist 的一種易於理解的方法:

我們在 ix 和每行 ixs 之間使用加權漢明距離。 如果行相同,則此距離為 0(我們使用它來仔細檢查 ix 是否在 ixs 中)並為每個差異添加一個懲罰。 我們選擇的權重使得 position 0,2 或 4 的差異增加了 3/11,而 position 1 或 3 的差異增加了 1/11。 稍后,我們只保留距離 < 1/4 的向量,這允許在 1 或 3 或兩者處偏離 ix 的向量通過並阻止所有其他向量。 然后我們分別檢查 position 3 中的 1。

from scipy.spatial.distance import cdist

# compute distance note that weights are automatically normalized to sum 1
d = cdist([ix],ixs,"hamming",w=[3,1,3,1,3])[0]
# find ix
ixloc = d.argmin()
# make sure its exactly ix
assert d[ixloc] == 0

# filter out all rows that are different in col 0,2 or 4
hits, = ((d < 1/4) & (ixs[:,3] == 1)).nonzero()
# only keep hits below the row of ix:
hits = hits[hits.searchsorted(ixloc):]

hits
# array([1, 4, 6])

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM