繁体   English   中英

在 numpy 矩阵中搜索向量

[英]Searching for vectors within a numpy matrix

给定以下带有索引的矩阵ixs ,我正在 ixs 中寻找一个与ix等效的向量(也是ixs的行/向量),除了维度 1 (可以假定任何值)和维度 3 需要设置为1 .

ixs = np.asarray([
 [0, 0, 3, 0, 1], # 0. current value of `ix`
 [0, 0, 3, 1, 1], # 1.
 [0, 1, 3, 0, 0], # 2.
 [0, 1, 3, 0, 1], # 3.
 [0, 1, 3, 1, 1], # 4.
 [0, 2, 3, 0, 1], # 5.
 [0, 2, 3, 1, 1]  # 6.
])
ix = np.asarray([0, 0, 3, 0, 1])

因此,对于[0, 0, 3, 0, 1]ix ,我将查看低于该行(第 1..6 行)的所有行,并查找模式[0, *, 3, 1, 1]即 1. [0, 0, 3, 1, 1] , 4. [0, 1, 3, 1, 1] , 6. [0, 2, 3, 1, 1]

获得这些向量的最佳(简洁)方法是什么?

此解决方案仅使用 numpy(非常快)和几个逻辑操作。 最后,它给出了正确的列。

ixs = np.matrix([
 [0, 0, 3, 0, 1], # 0. current value of `ix`
 [0, 0, 3, 1, 1], # 1.
 [0, 1, 3, 0, 0], # 2.
 [0, 1, 3, 0, 1], # 3.
 [0, 1, 3, 1, 1], # 4.
 [0, 2, 3, 0, 1], # 5.
 [0, 2, 3, 1, 1]  # 6.
])

newixs = ixs

#since the second column does not matter, we just assign it 0 in the new matrix.

newixs[:,1] = 0 

#here it compares the each row against the 0 indexed row
#then, it multiplies the True and False values with 1
#and the result is 0,1 values in an array. 
#then it takes the averages at the row level
#if the average is 1, then it means that all values match

mask = ((newixs == newixs[0])*1).mean(axis=1) == 1

#it then converts the matrix to array for masking
mask = np.squeeze(np.asarray(mask))

#using the mask value, we select the matched columns
ixs[mask,:]
matrix([[0, 0, 3, 0, 1],
        [0, 1, 3, 0, 1],
        [0, 2, 3, 0, 1]])

这是使用 cdist 的一种易于理解的方法:

我们在 ix 和每行 ixs 之间使用加权汉明距离。 如果行相同,则此距离为 0(我们使用它来仔细检查 ix 是否在 ixs 中)并为每个差异添加一个惩罚。 我们选择的权重使得 position 0,2 或 4 的差异增加了 3/11,而 position 1 或 3 的差异增加了 1/11。 稍后,我们只保留距离 < 1/4 的向量,这允许在 1 或 3 或两者处偏离 ix 的向量通过并阻止所有其他向量。 然后我们分别检查 position 3 中的 1。

from scipy.spatial.distance import cdist

# compute distance note that weights are automatically normalized to sum 1
d = cdist([ix],ixs,"hamming",w=[3,1,3,1,3])[0]
# find ix
ixloc = d.argmin()
# make sure its exactly ix
assert d[ixloc] == 0

# filter out all rows that are different in col 0,2 or 4
hits, = ((d < 1/4) & (ixs[:,3] == 1)).nonzero()
# only keep hits below the row of ix:
hits = hits[hits.searchsorted(ixloc):]

hits
# array([1, 4, 6])

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM