
Comparing two dictionaries with numpy matrices as values

I want to assert that two Python dictionaries are equal (that is: the same number of keys, and each key maps to an equal value; order is not important). A simple way would be assert A == B; however, this does not work if the values of the dictionaries are numpy arrays. How can I write a function that checks in general whether two dictionaries are equal?

>>> import numpy as np
>>> A = {1: np.identity(5)}
>>> B = {1: np.identity(5) + np.ones([5,5])}
>>> A == B
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

EDIT: I am aware that numpy matrices should be checked for equality with .all(). What I am looking for is a general way to check for this, without having to do an isinstance check against np.ndarray. Would this be possible?

Related topics without numpy arrays:

I'm going to answer the half-question hidden in your question's title and first half, because frankly, this is a much more common problem and the existing answers don't address it very well. That question is: "How do I compare two dicts of numpy arrays for equality?"

The first part of the problem is checking the dicts "from afar": see that their keys are the same. If all the keys are the same, the second part is comparing each corresponding value.

Now the subtle issue is that a lot of numpy arrays are not integer-valued, and double-precision arithmetic is imprecise. So unless you have integer-valued (or other non-float) arrays, you will probably want to check that the values are almost the same, i.e. equal within machine precision. In that case you wouldn't use np.array_equal (which checks exact numerical equality), but rather np.allclose (which uses a finite tolerance for the relative and absolute error between two arrays).
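To make the distinction concrete, here is a minimal sketch contrasting the two checks. The arrays x and y are made up for illustration and are not part of the original answer.

```python
import numpy as np

# Illustrative data: 0.1 + 0.2 is not exactly 0.3 in double precision
x = np.array([0.1 + 0.2, 1.0])
y = np.array([0.3, 1.0])

# Exact check: same shape and every element compares equal with ==
exact = np.array_equal(x, y)

# Tolerant check: elementwise |x - y| <= atol + rtol * |y|
close = np.allclose(x, y)

print(exact, close)
```

The exact check fails because of the rounding error in the first element, while the tolerant check accepts it.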

The first one and a half parts of the problem are straightforward: check that the keys of the dicts agree, and use a generator expression to compare every value (with all outside the expression to verify that every item matches):

import numpy as np

# some dummy data

# these are equal exactly
dct1 = {'a': np.array([2, 3, 4])}
dct2 = {'a': np.array([2, 3, 4])}

# these are equal _roughly_
dct3 = {'b': np.array([42.0, 0.2])}
dct4 = {'b': np.array([42.0, 3*0.1 - 0.1])}  # still 0.2, right?

def compare_exact(first, second):
    """Return whether two dicts of arrays are exactly equal"""
    if first.keys() != second.keys():
        return False
    return all(np.array_equal(first[key], second[key]) for key in first)

def compare_approximate(first, second):
    """Return whether two dicts of arrays are roughly equal"""
    if first.keys() != second.keys():
        return False
    return all(np.allclose(first[key], second[key]) for key in first)

# let's try them:
print(compare_exact(dct1, dct2))  # True
print(compare_exact(dct3, dct4))  # False
print(compare_approximate(dct3, dct4))  # True

As you can see in the above example, the integer arrays compare fine exactly, and depending on what you're doing (or if you're lucky) exact comparison could even work for floats. But if your floats are the result of any kind of arithmetic (linear transformations, for instance), you should definitely use an approximate check. For a complete description of the latter option, see the docs of numpy.allclose (and its elementwise friend, numpy.isclose), with special regard to the rtol and atol keyword arguments.
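As a rough sketch of how rtol and atol could be exposed, the approximate comparison might forward them to np.allclose. The function name compare_with_tolerance and the sample dicts are hypothetical; the default values shown are the documented defaults of np.allclose.

```python
import numpy as np

def compare_with_tolerance(first, second, rtol=1e-5, atol=1e-8):
    """Hypothetical helper: approximate dict comparison with explicit tolerances."""
    if first.keys() != second.keys():
        return False
    # Each pair passes if |a - b| <= atol + rtol * |b| holds elementwise
    return all(
        np.allclose(first[key], second[key], rtol=rtol, atol=atol)
        for key in first
    )

# Illustrative data: the second entries differ by about 1e-3
a = {'x': np.array([1.0, 2.0])}
b = {'x': np.array([1.0, 2.001])}

loose = compare_with_tolerance(a, b, rtol=1e-3)  # tolerance wide enough for the gap
strict = compare_with_tolerance(a, b)            # default tolerances are too tight

print(loose, strict)
```

Note that np.allclose broadcasts its arguments, so a stricter variant might also compare the array shapes before comparing the values.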

You can separate the keys and values of both dicts and compare keys against keys and values against values. Here's the solution:

import numpy as np

def dic_to_keys_values(dic):
    keys, values = list(dic.keys()), list(dic.values())
    return keys, values

def numpy_assert_almost_dict_values(dict1, dict2):
    keys1, values1 = dic_to_keys_values(dict1)
    keys2, values2 = dic_to_keys_values(dict2)
    np.testing.assert_equal(keys1, keys2)
    np.testing.assert_almost_equal(values1, values2)

dict1 = {"b": np.array([1, 2, 0.2])}
dict2 = {"b": np.array([1, 2, 3 * 0.1 - 0.1])}  # almost 0.2, but not equal
dict3 = {"b": np.array([999, 888, 444])} # completely different

numpy_assert_almost_dict_values(dict1, dict2) # no exception because almost equal
# numpy_assert_almost_dict_values(dict1, dict3) # exception because not equal

(Note: the above checks for exactly equal keys and almost equal values.)
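One caveat: because the helper above compares the keys as lists, two dicts holding the same keys in a different insertion order would fail the check. A minimal order-independent sketch, assuming a key-by-key comparison is acceptable, could look like this (the name assert_dicts_almost_equal and the sample data are hypothetical):

```python
import numpy as np

def assert_dicts_almost_equal(dict1, dict2, decimal=7):
    """Hypothetical helper: order-independent variant of the check above."""
    # dict key views compare like sets, so insertion order is irrelevant
    assert dict1.keys() == dict2.keys(), "dicts have different keys"
    # Compare the values key by key instead of position by position
    for key in dict1:
        np.testing.assert_almost_equal(dict1[key], dict2[key], decimal=decimal)

d1 = {"a": np.array([1.0, 2.0]), "b": np.array([0.2])}
d2 = {"b": np.array([3 * 0.1 - 0.1]), "a": np.array([1.0, 2.0])}  # same keys, other order

assert_dicts_almost_equal(d1, d2)  # no exception because almost equal
```

Comparing key by key also avoids stacking the values into one array, so the values of different keys may have different shapes.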

Consider this code:

>>> import numpy as np
>>> np.identity(5)
array([[ 1.,  0.,  0.,  0.,  0.],
       [ 0.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  1.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  0.],
       [ 0.,  0.,  0.,  0.,  1.]])
>>> np.identity(5)+np.ones([5,5])
array([[ 2.,  1.,  1.,  1.,  1.],
       [ 1.,  2.,  1.,  1.,  1.],
       [ 1.,  1.,  2.,  1.,  1.],
       [ 1.,  1.,  1.,  2.,  1.],
       [ 1.,  1.,  1.,  1.,  2.]])
>>> np.identity(5) == np.identity(5)+np.ones([5,5])
array([[False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False]], dtype=bool)
>>> 

Note that the result of the comparison is a matrix, not a boolean value. Dict comparison compares the values using the values' own comparison methods, which means that for matrix values each comparison yields an elementwise array rather than a single boolean; evaluating the truth value of that array raises the ValueError shown in the question. What you want to do is use numpy.all to collapse the composite array result into a scalar boolean result:

>>> np.all(np.identity(5) == np.identity(5)+np.ones([5,5]))
False
>>> np.all(np.identity(5) == np.identity(5))
True
>>> 

You would need to write your own function to compare these dictionaries, testing value types to see if they are matrices, and then comparing using numpy.all, otherwise using ==. Of course, you can always get fancy and subclass dict, overriding its comparison method (__eq__, or __cmp__ in Python 2), if you want to.
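A minimal sketch of such a function might look as follows. The name dicts_equal and the sample dicts are hypothetical, and np.array_equal is swapped in for np.all(a == b) on the assumption that arrays of different shapes should count as unequal.

```python
import numpy as np

def dicts_equal(first, second):
    """Hypothetical helper: compare dicts whose values may or may not be arrays."""
    if first.keys() != second.keys():
        return False
    for key in first:
        a, b = first[key], second[key]
        if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
            # np.array_equal returns a single bool and is False on shape mismatch
            if not np.array_equal(a, b):
                return False
        elif a != b:
            # Ordinary values fall back to plain ==
            return False
    return True

A = {1: np.identity(5), 'name': 'x'}
B = {1: np.identity(5) + np.ones([5, 5]), 'name': 'x'}
C = {'name': 'x', 1: np.identity(5)}

print(dicts_equal(A, C))  # same keys and values, different insertion order
print(dicts_equal(A, B))  # array values differ
```

This only handles arrays at the top level of the dict; nested containers holding arrays would need a recursive variant.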

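Notice: the technical posts on this site are published under the CC BY-SA 4.0 license. If you republish them, please cite this site's URL or the original address. For any questions, contact: yoyou2525@163.com.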

 