How to compare equality of dataclasses holding numpy.ndarray (bool(a==b) raises ValueError)?

If I create a Python dataclass containing a NumPy ndarray, I can no longer use the automatically generated __eq__.

from dataclasses import dataclass
import numpy as np

@dataclass
class Instr:
    foo: np.ndarray
    bar: np.ndarray

arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

This is because ndarray.__eq__ returns an ndarray of truth values, obtained by comparing a[0] to b[0] and so on element by element, rather than a single bool. This is quite unintuitive: the error does not always appear, and whether it is raised seems to depend on the shapes and values of the arrays being compared.
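To make the failure concrete, here is a minimal sketch of the elementwise comparison and the bool() conversion that follows it. The variable names are purely illustrative and are not part of the question's code.

```python
import numpy as np

a = np.array([1, 2])
b = np.array([1, 3])

# == on ndarrays compares element by element and returns an array.
elementwise = a == b

# Converting a multi-element array to a single bool is what fails.
try:
    bool(elementwise)
    message = None
except ValueError as exc:
    message = str(exc)

print(elementwise)
print(message)
```

Reducing the array explicitly with .all() or .any() removes the ambiguity, which is what the error message suggests.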

How do I safely compare @dataclasses holding NumPy arrays?


@dataclass's implementation of __eq__ is generated dynamically using exec(). Its source is missing from the stack trace and cannot be viewed using inspect, but it is actually a tuple comparison, which calls bool() on the result of comparing each pair of fields (for example bool(self.foo == other.foo)).

import dis
dis.dis(Instr.__eq__)

Excerpt:

  3          12 LOAD_FAST                0 (self)
             14 LOAD_ATTR                1 (foo)
             16 LOAD_FAST                0 (self)
             18 LOAD_ATTR                2 (bar)
             20 BUILD_TUPLE              2
             22 LOAD_FAST                1 (other)
             24 LOAD_ATTR                1 (foo)
             26 LOAD_FAST                1 (other)
             28 LOAD_ATTR                2 (bar)
             30 BUILD_TUPLE              2
             32 COMPARE_OP               2 (==)
             34 RETURN_VALUE
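The same behavior can be reproduced without dataclasses at all, which may help confirm that the tuple comparison is the culprit. This is a small sketch with made-up variable names. It relies on the fact that tuple comparison treats identical objects as equal before falling back to ==, which would also explain why the error only appears in some cases.

```python
import numpy as np

x = np.array([1, 2])
y = np.array([1, 2])  # same contents as x, but a different object

# Identical objects are considered equal without calling ndarray.__eq__.
same_object = (x,) == (x,)

# Distinct objects are compared with ==, and the resulting array is
# converted with bool(), which raises for multi-element arrays.
try:
    (x,) == (y,)
    raised = False
except ValueError:
    raised = True

print(same_object, raised)
```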

The solution is to define your own __eq__ method and set eq=False so the dataclass doesn't generate its own. (According to the docs, that last step isn't strictly necessary, because the eq parameter is ignored when the class already defines __eq__, but I think it's nice to be explicit anyway.)

from dataclasses import dataclass
import numpy as np

def array_eq(arr1, arr2):
    return (isinstance(arr1, np.ndarray) and
            isinstance(arr2, np.ndarray) and
            arr1.shape == arr2.shape and
            (arr1 == arr2).all())

@dataclass(eq=False)
class Instr:

    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        return array_eq(self.foo, other.foo) and array_eq(self.bar, other.bar)
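As a possible alternative to the hand-written helper, NumPy's np.array_equal performs the shape check and the elementwise reduction in one call. The sketch below swaps it in for array_eq. Note one difference: np.array_equal accepts any array-like input, so it does not enforce the isinstance checks that array_eq does.

```python
from dataclasses import dataclass
import numpy as np

@dataclass(eq=False)
class Instr:
    foo: np.ndarray
    bar: np.ndarray

    def __eq__(self, other):
        if not isinstance(other, Instr):
            return NotImplemented
        # np.array_equal returns False on a shape mismatch instead of raising.
        return (np.array_equal(self.foo, other.foo)
                and np.array_equal(self.bar, other.bar))

arr = np.array([1])
arr2 = np.array([1, 2])
print(Instr(arr, arr) == Instr(arr2, arr2))
print(Instr(arr2, arr2) == Instr(np.array([1, 2]), np.array([1, 2])))
```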

Edit

A general and quick solution for generic dataclasses where some values are NumPy arrays and others are not:

import numpy as np
from dataclasses import dataclass, astuple

def array_safe_eq(a, b) -> bool:
    """Check if a and b are equal, even if they are numpy arrays"""
    if a is b:
        return True
    if isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
        return a.shape == b.shape and (a == b).all()
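    try:
        return bool(a == b)
    except (TypeError, ValueError):
        # No single truth value is available (e.g. an array compared with a
        # non-array). Treat as unequal: NotImplemented would be truthy in all().
        return False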

def dc_eq(dc1, dc2):
    """Check if two dataclasses which hold numpy arrays are equal."""
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        return NotImplemented  # better than False
    t1 = astuple(dc1)
    t2 = astuple(dc2)
    return all(array_safe_eq(a1, a2) for a1, a2 in zip(t1, t2))

# usage
@dataclass(eq=False)
class T:
    a: int
    b: np.ndarray
    c: np.ndarray

    def __eq__(self, other):
        return dc_eq(self, other)
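One thing to be aware of is that astuple() deep-copies field values, including the arrays, on every comparison. If that cost matters, a variant that reads the fields directly might look like the sketch below. The names field_eq and dc_eq_no_copy are hypothetical and not part of the answer above, and skipping fields declared with compare=False is an assumption about the desired semantics.

```python
from dataclasses import dataclass, fields
import numpy as np

def field_eq(a, b) -> bool:
    """Compare two field values, handling numpy arrays explicitly."""
    if a is b:
        return True
    if isinstance(a, np.ndarray) or isinstance(b, np.ndarray):
        # Only equal when both are arrays with the same shape and contents.
        return (isinstance(a, np.ndarray) and isinstance(b, np.ndarray)
                and np.array_equal(a, b))
    return bool(a == b)

def dc_eq_no_copy(dc1, dc2):
    """Field-by-field equality that does not copy the field values."""
    if dc1 is dc2:
        return True
    if dc1.__class__ is not dc2.__class__:
        return NotImplemented
    return all(field_eq(getattr(dc1, f.name), getattr(dc2, f.name))
               for f in fields(dc1) if f.compare)

@dataclass(eq=False)
class T:
    a: int
    b: np.ndarray
    c: np.ndarray

    def __eq__(self, other):
        return dc_eq_no_copy(self, other)

t1 = T(1, np.array([1, 2]), np.array([3]))
t2 = T(1, np.array([1, 2]), np.array([3]))
t3 = T(1, np.array([1, 2, 3]), np.array([3]))
print(t1 == t2, t1 == t3)
```

Unlike astuple(), this does not recurse into nested dataclasses, so nested instances are compared with their own __eq__.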
