[英]Searching for vectors within a numpy matrix
给定以下带有索引的矩阵ixs
,我正在 ixs 中寻找一个与ix
等效的向量(也是ixs
的行/向量),除了维度 1 (可以假定任何值)和维度 3 需要设置为1
.
ixs = np.asarray([
[0, 0, 3, 0, 1], # 0. current value of `ix`
[0, 0, 3, 1, 1], # 1.
[0, 1, 3, 0, 0], # 2.
[0, 1, 3, 0, 1], # 3.
[0, 1, 3, 1, 1], # 4.
[0, 2, 3, 0, 1], # 5.
[0, 2, 3, 1, 1] # 6.
])
ix = np.asarray([0, 0, 3, 0, 1])
因此,对于[0, 0, 3, 0, 1]
的ix
,我将查看低于该行(第 1..6 行)的所有行,并查找模式[0, *, 3, 1, 1]
即 1. [0, 0, 3, 1, 1]
, 4. [0, 1, 3, 1, 1]
, 6. [0, 2, 3, 1, 1]
。
获得这些向量的最佳(简洁)方法是什么?
此解决方案仅使用 numpy(非常快)和几个逻辑操作。 最后,它给出了正确的列。
ixs = np.matrix([
[0, 0, 3, 0, 1], # 0. current value of `ix`
[0, 0, 3, 1, 1], # 1.
[0, 1, 3, 0, 0], # 2.
[0, 1, 3, 0, 1], # 3.
[0, 1, 3, 1, 1], # 4.
[0, 2, 3, 0, 1], # 5.
[0, 2, 3, 1, 1] # 6.
])
newixs = ixs
#since the second column does not matter, we just assign it 0 in the new matrix.
newixs[:,1] = 0
#here it compares the each row against the 0 indexed row
#then, it multiplies the True and False values with 1
#and the result is 0,1 values in an array.
#then it takes the averages at the row level
#if the average is 1, then it means that all values match
mask = ((newixs == newixs[0])*1).mean(axis=1) == 1
#it then converts the matrix to array for masking
mask = np.squeeze(np.asarray(mask))
#using the mask value, we select the matched columns
ixs[mask,:]
matrix([[0, 0, 3, 0, 1],
[0, 1, 3, 0, 1],
[0, 2, 3, 0, 1]])
这是使用 cdist 的一种易于理解的方法:
我们在 ix 和每行 ixs 之间使用加权汉明距离。 如果行相同,则此距离为 0(我们使用它来仔细检查 ix 是否在 ixs 中)并为每个差异添加一个惩罚。 我们选择的权重使得 position 0,2 或 4 的差异增加了 3/11,而 position 1 或 3 的差异增加了 1/11。 稍后,我们只保留距离 < 1/4 的向量,这允许在 1 或 3 或两者处偏离 ix 的向量通过并阻止所有其他向量。 然后我们分别检查 position 3 中的 1。
from scipy.spatial.distance import cdist
# compute distance note that weights are automatically normalized to sum 1
d = cdist([ix],ixs,"hamming",w=[3,1,3,1,3])[0]
# find ix
ixloc = d.argmin()
# make sure its exactly ix
assert d[ixloc] == 0
# filter out all rows that are different in col 0,2 or 4
hits, = ((d < 1/4) & (ixs[:,3] == 1)).nonzero()
# only keep hits below the row of ix:
hits = hits[hits.searchsorted(ixloc):]
hits
# array([1, 4, 6])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.