
Extract indices of intersecting array from numpy 2D array in subarray

I have two 2D NumPy square arrays, A and B. B is extracted from A by stripping a certain number of rows and columns (with the same indices). Both arrays are symmetric. For instance, A and B could be:

A = np.array([[1,2,3,4,5],
              [2,7,8,9,10],
              [3,8,13,14,15],
              [4,9,14,19,20],
              [5,10,15,20,25]])
B = np.array([[1,3,5],
              [3,13,15],
              [5,15,25]])

such that the missing indices are [1,3] and the intersecting indices are [0,2,4].
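In other words, the task is to invert the following relationship: indexing A with the intersecting indices on both axes (via `np.ix_`) reproduces B, and the missing indices are the complement. A minimal sketch using the example arrays above; the names `keep` and `missing` are illustrative only:

```python
import numpy as np

A = np.array([[1, 2, 3, 4, 5],
              [2, 7, 8, 9, 10],
              [3, 8, 13, 14, 15],
              [4, 9, 14, 19, 20],
              [5, 10, 15, 20, 25]])
B = np.array([[1, 3, 5],
              [3, 13, 15],
              [5, 15, 25]])

keep = np.array([0, 2, 4])  # intersecting indices (the unknown we want to recover)
# np.ix_ builds an open mesh, so the same indices select both rows and columns
assert np.array_equal(A[np.ix_(keep, keep)], B)

# the missing indices are simply the complement of keep
missing = np.setdiff1d(np.arange(len(A)), keep)
print(missing.tolist())  # [1, 3]
```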

Is there a "smart" way, using advanced indexing and the like, to extract the indices in A corresponding to the rows/columns present in B? All I could come up with was:

        import numpy as np
        index = []
        n, m = len(A), len(B)
        for j in range(m):  # one index per row of B
            k = 0
            # advance k until row A[k] contains every value of row B[j]
            while k < n and set(np.intersect1d(B[j], A[k])) != set(B[j]):
                k += 1
            index.append(k)

which I'm aware is slow and resource-consuming when dealing with large arrays.

Thank you!

Edit: I did find a smarter way. I extract the diagonal from both arrays and perform the aforementioned loop on it with a simple equality check:

        index = []
        a = np.diag(A)
        b = np.diag(B)
        for j in range(len(b)):
            k = 0
            # scan forward from position j until the diagonals match
            while j + k < len(a) and a[j+k] != b[j]:
                k += 1
            index.append(k+j)

Although it still doesn't use advanced indexing and still iterates over a potentially long list, this partial solution looks cleaner, and I'm going to stick with it for the time being.
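The diagonal trick relies on the fact that the diagonal of B is the diagonal of A restricted to the intersecting indices. A short check on the example arrays, reusing the known answer `keep = [0, 2, 4]` purely for illustration. Note that a matching diagonal value is only a necessary condition, so equal diagonal values at other positions can still mislead the scan:

```python
import numpy as np

A = np.array([[1, 2, 3, 4, 5],
              [2, 7, 8, 9, 10],
              [3, 8, 13, 14, 15],
              [4, 9, 14, 19, 20],
              [5, 10, 15, 20, 25]])
keep = [0, 2, 4]
B = A[np.ix_(keep, keep)]

# taking the diagonal commutes with selecting the same rows and columns
print(np.diag(B).tolist())        # [1, 13, 25]
print(np.diag(A)[keep].tolist())  # [1, 13, 25]
```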

Consider the easy case when all the values are distinct:

A = np.arange(25).reshape(5,5)
ans = [1,3,4]
B = A[np.ix_(ans, ans)]

In [287]: A
Out[287]: 
array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24]])

In [288]: B
Out[288]: 
array([[ 6,  8,  9],
       [16, 18, 19],
       [21, 23, 24]])

If we test the first row of B against each row of A, we will eventually compare [6, 8, 9] with [5, 6, 7, 8, 9], from which we can glean the candidate solution of indices [1, 3, 4].
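Concretely, the candidate is read off by locating each value of B's row inside A's row with `np.where`. A minimal sketch for this distinct-values case; taking the first match is only safe because no value repeats:

```python
import numpy as np

A = np.arange(25).reshape(5, 5)
rowA = A[1]                 # [5, 6, 7, 8, 9]
rowB = np.array([6, 8, 9])  # first row of B

# position of each value of rowB inside rowA; [0][0] takes the first (only) match
idx = [int(np.where(rowA == b)[0][0]) for b in rowB]
print(idx)  # [1, 3, 4]
```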

We can generate the set of all possible candidate solutions by pairing the first row of B with each row of A.
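For the distinct-values example, that pairing step might look like the following sketch. Rows of A that do not contain every value of B's first row are skipped, and only the first match per value is taken, which again assumes no repeated values (the find_idx_1d helper further down handles repeats with `itertools.product`):

```python
import numpy as np

A = np.arange(25).reshape(5, 5)
ans = [1, 3, 4]
B = A[np.ix_(ans, ans)]

candidates = set()
for rowA in A:
    # only rows containing every value of B[0] can produce a candidate
    if np.isin(B[0], rowA).all():
        candidates.add(tuple(int(np.where(rowA == b)[0][0]) for b in B[0]))
print(candidates)  # {(1, 3, 4)}
```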

If there is only one candidate, then we are done: we are given that B is a submatrix of A, so a solution always exists, and it must be that candidate.

If there is more than one candidate, we can do the same thing with the second row of B and take the intersection of the candidate solutions; after all, a solution must be a solution for every row of B.
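To see the intersection at work, here is a small symmetric matrix with a repeated value, made up for illustration. The first row of B fits two positions in A, and the second row rules one of them out. The helper `candidates_for_row` is hypothetical; it mirrors the idea of find_idx_1d but collects candidates from every row of A at once:

```python
import itertools as IT
import numpy as np

def candidates_for_row(A, rowB):
    """All sorted index tuples that place rowB's values inside some row of A."""
    found = set()
    for rowA in A:
        if np.isin(rowB, rowA).all():
            # a value may occur at several positions; try every combination
            positions = [np.where(rowA == b)[0] for b in rowB]
            found.update(tuple(sorted(int(i) for i in idx))
                         for idx in IT.product(*positions))
    return found

A = np.array([[1, 2, 2, 0],
              [2, 3, 5, 0],
              [2, 5, 4, 0],
              [0, 0, 0, 9]])
B = A[np.ix_([0, 1], [0, 1])]  # [[1, 2], [2, 3]]

first = candidates_for_row(A, B[0])
second = candidates_for_row(A, B[1])
print(sorted(first))   # [(0, 1), (0, 2)]
print(first & second)  # {(0, 1)}
```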

Thus we can loop through the rows of B and short-circuit as soon as only one candidate remains. Again, we are assuming that B is always a submatrix of A.

The find_idx function below implements the idea described above:

import itertools as IT
import numpy as np

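def find_idx_1d(rowA, rowB):
    """Return every sorted index tuple that places the values of rowB inside rowA."""
    result = []
    if np.isin(rowB, rowA).all():
        # each value of rowB may occur at several positions in rowA;
        # itertools.product enumerates every combination of those positions
        result = [tuple(sorted(idx)) 
                  for idx in IT.product(*[np.where(rowA==b)[0] for b in rowB])]
    return result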

def find_idx(A, B):
    candidates = set([idx for row in A for idx in find_idx_1d(row, B[0])])
    for Bi in B[1:]:
        if len(candidates) == 1:
            # stop when there is a unique candidate
            return candidates.pop()
        new = [idx for row in A for idx in find_idx_1d(row, Bi)]  
        candidates = candidates.intersection(new)
    if candidates:
        return candidates.pop()
    raise ValueError('no solution found')

Correctness: The two solutions you've proposed may not always return the correct result, particularly when there are repeated values. For example,

def is_solution(A, B, idx):
    return np.allclose(A[np.ix_(idx, idx)], B)

def find_idx_orig(A, B):
    index = []
    for j in range(len(B)):
        k = 0
        while k<len(A) and set(np.intersect1d(B[j],A[k])) != set(B[j]):
            k+=1
        index.append(k)
    return index

def find_idx_diag(A, B):
    index = []
    a = np.diag(A)
    b = np.diag(B)
    for j in range(len(b)):
        k = 0
        while j+k < len(A) and a[j+k] != b[j]:
            k+=1
        index.append(k+j)
    return index

def counterexample():
    """
    Show find_idx_diag, find_idx_orig may not return the correct result
    """
    A = np.array([[1,2,0],
                  [2,1,0],
                  [0,0,1]])
    ans = [0,1]
    B = A[np.ix_(ans, ans)]
    assert not is_solution(A, B, find_idx_orig(A, B))
    assert is_solution(A, B, find_idx(A, B))

    A = np.array([[1,2,0],
                  [2,1,0],
                  [0,0,1]])
    ans = [1,2]
    B = A[np.ix_(ans, ans)]

    assert not is_solution(A, B, find_idx_diag(A, B))
    assert is_solution(A, B, find_idx(A, B))

counterexample()
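With repeated values the answer may not even be unique. In the second counterexample, both [0, 2] and [1, 2] reproduce B, so find_idx may legitimately return either one (it pops an arbitrary element from a set of surviving candidates). A quick check:

```python
import numpy as np

A = np.array([[1, 2, 0],
              [2, 1, 0],
              [0, 0, 1]])
B = A[np.ix_([1, 2], [1, 2])]

# both index sets select the same 2x2 submatrix [[1, 0], [0, 1]]
for idx in ([0, 2], [1, 2]):
    print(idx, np.array_equal(A[np.ix_(idx, idx)], B))
# [0, 2] True
# [1, 2] True
```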

Benchmark: Ignoring the issue of correctness at our peril, out of curiosity let's compare these functions on the basis of speed.

def make_AB(n, m):
    A = symmetrize(np.random.random((n, n)))
    ans = np.sort(np.random.choice(n, m, replace=False))
    B = A[np.ix_(ans, ans)]
    return A, B

def symmetrize(a):
    "http://stackoverflow.com/a/2573982/190597 (EOL)"
    return a + a.T - np.diag(a.diagonal())

if __name__ == '__main__':
    counterexample()
    A, B = make_AB(500, 450)
    assert is_solution(A, B, find_idx(A, B))

In [283]: %timeit find_idx(A, B)
10 loops, best of 3: 74 ms per loop

In [284]: %timeit find_idx_orig(A, B)
1 loops, best of 3: 14.5 s per loop

In [285]: %timeit find_idx_diag(A, B)
100 loops, best of 3: 2.93 ms per loop

So find_idx is much faster than find_idx_orig, but not as fast as find_idx_diag.

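Statement: The technical posts on this site are licensed under CC BY-SA 4.0. If you need to repost, please cite this site's URL or the original address. For any questions, contact: yoyou2525@163.com.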

 
Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM