
Check if two numpy arrays are identical

Suppose I have a bunch of arrays, including x and y, and I want to check if they're equal. Generally, I can just use np.all(x == y) (barring some dumb corner cases which I'm ignoring now).

However, this evaluates the entire array (x == y), which is usually not needed. My arrays are really large, and I have a lot of them, and the probability of two arrays being equal is small, so in all likelihood I really only need to evaluate a very small portion of (x == y) before the all function could return False, so this is not an optimal solution for me.

I've tried using the builtin all function in combination with itertools.izip: all(val1 == val2 for val1, val2 in itertools.izip(x, y))

However, that just seems much slower in the case that the two arrays are equal, so much so that overall it's still not worth using over np.all. I presume that's because of the builtin all's general-purposeness. And np.all doesn't work on generators.
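For reference, itertools.izip is Python 2 only; in Python 3 the builtin zip is already lazy, so the equivalent of the attempt above (slow for the same per-element-overhead reason) would be:

```python
import numpy as np

x = np.arange(1000)
y = np.arange(1000)

# zip is lazy in Python 3, so the generator stops at the first mismatch,
# but it still pays Python-level overhead for every element compared.
equal = x.shape == y.shape and all(a == b for a, b in zip(x.flat, y.flat))
```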

Is there a way to do what I want in a more speedy manner?

I know this question is similar to previously asked questions (e.g. Comparing two numpy arrays for equality, element-wise), but they specifically don't cover the case of early termination.

Until this is implemented in numpy natively, you can write your own function and jit-compile it with numba:

import numpy as np
import numba as nb


@nb.jit(nopython=True)
def arrays_equal(a, b):
    if a.shape != b.shape:
        return False
    for ai, bi in zip(a.flat, b.flat):
        if ai != bi:
            return False
    return True


a = np.random.rand(10, 20, 30)
b = np.random.rand(10, 20, 30)


%timeit np.all(a==b)  # 100000 loops, best of 3: 9.82 µs per loop
%timeit arrays_equal(a, a)  # 100000 loops, best of 3: 9.89 µs per loop
%timeit arrays_equal(a, b)  # 100000 loops, best of 3: 691 ns per loop

Worst-case performance (arrays equal) is equivalent to np.all, and in the case of early stopping the compiled function has the potential to outperform np.all by a lot.

Adding short-circuit logic to array comparisons is apparently being discussed on the numpy page on GitHub, so it may be available in a future version of numpy.

Hmmm, I know it is a poor answer, but it seems there is no easy way for this. The numpy creators should fix it. I suggest:

def compare(a, b):
    if len(a) > 0 and not np.array_equal(a[0], b[0]):
        return False
    if len(a) > 15 and not np.array_equal(a[:15], b[:15]):
        return False
    if len(a) > 200 and not np.array_equal(a[:200], b[:200]):
        return False
    return np.array_equal(a, b)
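The same idea can be generalized (this is a sketch of my own, not part of the answer above) by growing the prefix size exponentially, so a mismatch near the front is found after touching only a small multiple of its position, while the worst case stays within a constant factor of a full comparison:

```python
import numpy as np

def compare_chunked(a, b, start=16):
    """Compare prefixes of exponentially growing size, so a mismatch
    near the front is detected without scanning the whole array."""
    if a.shape != b.shape:
        return False
    n = a.shape[0]
    size = start
    while size < n:
        # Re-checks earlier prefixes, but the geometric growth keeps
        # the total overhead bounded (~2x in the worst case).
        if not np.array_equal(a[:size], b[:size]):
            return False
        size *= 2
    return np.array_equal(a, b)
```

This avoids hand-picking cutoffs like 15 and 200, at the cost of re-comparing each prefix once per doubling.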

:)

Well, not really an answer, as I haven't checked if it short-circuits, but:

assert_array_equal.

From the documentation:

Raises an AssertionError if two array_like objects are not equal.

Wrap it in try/except if you're not on a performance-sensitive code path.

Or follow the underlying source code; maybe it's efficient.
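A minimal try/except wrapper might look like this (note that, as far as I can tell, assert_array_equal still evaluates the full comparison internally in order to report mismatch statistics, so this is about convenience rather than early termination):

```python
import numpy as np
from numpy.testing import assert_array_equal

def arrays_match(a, b):
    # Convert the AssertionError raised on mismatch into a boolean.
    try:
        assert_array_equal(a, b)
        return True
    except AssertionError:
        return False
```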

You could iterate over all elements of the arrays and check if they are equal. If the arrays are most likely not equal, it will return much faster than the .all function. Something like this:

import numpy as np

a = np.array([1, 2, 3])
b = np.array([1, 3, 4])

areEqual = True

for x in range(a.size):  # note: range(0, a.size-1) would skip the last element
    if a[x] != b[x]:
        areEqual = False
        break

if areEqual:
    print("The tables are equal")
else:
    print("The tables are not equal")

Probably someone who understands the underlying data structure could optimize this or explain whether it's reliable/safe/good practice, but it seems to work.

np.all(a==b)
Out[]: True

memoryview(a.data)==memoryview(b.data)
Out[]: True

%timeit np.all(a==b)
The slowest run took 10.82 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 6.2 µs per loop

%timeit memoryview(a.data)==memoryview(b.data)
The slowest run took 8.55 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 1.85 µs per loop

If I understand this correctly, ndarray.data creates a pointer to the data buffer, and memoryview creates a native Python type whose comparison can short-circuit out of the buffer.

I think.

EDIT: further testing shows it may not be as big a time-improvement as shown. The previous timings used a = b = np.eye(5):

a=np.random.randint(0,10,(100,100))

b=a.copy()

%timeit np.all(a==b)
The slowest run took 6.70 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 17.7 µs per loop

%timeit memoryview(a.data)==memoryview(b.data)
10000 loops, best of 3: 30.1 µs per loop

np.all(a==b)
Out[]: True

memoryview(a.data)==memoryview(b.data)
Out[]: True

As Thomas Kühn wrote in a comment to your post, array_equal is a function which should solve the problem. It is described in Numpy's API reference.

Breaking down the original problem into three parts: "(1) My arrays are really large, and (2) I have a lot of them, and (3) the probability of two arrays being equal is small".

All the solutions (to date) focus on part (1), optimizing the performance of each equality check, and some improve this performance by a factor of 10. Points (2) and (3) are ignored. Comparing each pair has O(n^2) complexity, which can become huge for a lot of matrices, and needlessly so, since the probability of any two being duplicates is very small.

The check can become much faster with the following general algorithm:

  • compute a fast hash of each array, O(n)
  • check equality only for arrays with the same hash

A good hash is almost unique, so the number of keys can easily be a very large fraction of n. On average, the number of arrays with the same hash will be very small, and almost 1 in some cases. Duplicate arrays will have the same hash, while having the same hash doesn't guarantee they are duplicates; in that sense, the algorithm will catch all the duplicates. Comparing images only with the same hash significantly reduces the number of comparisons, which becomes almost O(n).

For my problem, I had to check for duplicates within ~1 million integer arrays, each with 10k elements. Optimizing only the array equality check (with @MB-F's solution), the estimated run time was 5 days. With hashing first, it finished in minutes. (I used the array sum as the hash; that was suited to my arrays' characteristics.)

Some pseudo-Python code:


import numpy as np
from collections import defaultdict

def fast_hash(arr) -> int:
    # cheap, collision-prone hash; the array sum suited my data
    return int(np.sum(arr))

def arrays_equal(arr1, arr2) -> bool:
    return np.array_equal(arr1, arr2)

def make_hash_dict(array_stack, hash_fn=np.sum):

    hash_dict = defaultdict(list)
    hashes = np.squeeze(np.apply_over_axes(hash_fn, array_stack, range(1, array_stack.ndim)))
    for idx, hash_val in enumerate(hashes):
        hash_dict[hash_val].append(idx)

    return hash_dict

def get_duplicate_sets(hash_dict, array_stack):

    duplicate_sets = []
    for hash_key, ind_list in hash_dict.items():
        if len(ind_list) == 1:
            continue

        all_duplicates = []
        for idx1 in range(len(ind_list)):
            v1 = ind_list[idx1]
            if v1 in all_duplicates:
                continue

            arr1 = array_stack[v1]
            curr_duplicates = []
            for idx2 in range(idx1+1, len(ind_list)):
                v2 = ind_list[idx2]
                arr2 = array_stack[v2]
                if arrays_equal(arr1, arr2):
                    if len(curr_duplicates) == 0:
                        curr_duplicates.append(v1)
                    curr_duplicates.append(v2)
            
            if len(curr_duplicates) > 0:
                all_duplicates.extend(curr_duplicates)
                duplicate_sets.append(curr_duplicates)

    return duplicate_sets


The variable duplicate_sets is a list of lists; each internal list contains the indices of one set of identical arrays.
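For completeness, here is a compact, self-contained variant of the same hash-then-verify idea (a sketch, using the array sum as the cheap hash, as above, and taking the arrays as a plain list):

```python
import numpy as np
from collections import defaultdict

def find_duplicate_groups(arrays):
    """Group arrays by a cheap hash, then confirm equality only within
    each bucket. Returns lists of indices of identical arrays."""
    buckets = defaultdict(list)
    for i, arr in enumerate(arrays):
        buckets[int(arr.sum())].append(i)

    groups = []
    for indices in buckets.values():
        # Within a bucket, peel off one group of exact duplicates at a time.
        while len(indices) > 1:
            first, *rest = indices
            group = [first]
            remaining = []
            for j in rest:
                if np.array_equal(arrays[first], arrays[j]):
                    group.append(j)
                else:
                    remaining.append(j)
            if len(group) > 1:
                groups.append(group)
            indices = remaining
    return groups
```

Only arrays whose sums collide are ever compared element-by-element, which is what brings the total cost down from O(n^2) full comparisons toward O(n) hashes plus a few confirmations.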
