[英]How to Implement pytest.approx() for Data Classes
假设我有一个 Python 数据 Class 我想用 pytest 进行测试:
@dataclass
class ExamplePoint:
x: float
y: float
通过简单的数学运算,一切都很好:
p1 = ExamplePoint(1,2)
p2 = ExamplePoint(0.5 + 0.5, 2)
p1 == p2 # True
但是浮点运算很快就会引起问题:
p3 = ExamplePoint(1, math.sqrt(2) * math.sqrt(2))
p1 == p3 # False
使用 pytest 您可以使用approx()
function 来解决这个问题:
2.0 == approx(math.sqrt(2) * math.sqrt(2)) # True
您不能简单地将其扩展到数据 Class:
p1 = approx(p3) # Results error: TypeError: cannot make approximate comparisons to non-numeric values: ExamplePoint(x=1, y=2.0000000000000004)
我目前的解决方案是在数据 Class 上写一个approx()
function 为:
from dataclasses import astuple, dataclass
import pytest
@dataclass
class ExamplePoint:
x: float
y: float
def approx(self, other):
return astuple(self) == pytest.approx(astuple(other))
p1 = ExamplePoint(1,2)
p3 = ExamplePoint(1, math.sqrt(2) * math.sqrt(2))
p1.approx(p3) # True
我不喜欢这个解决方案,因为 ExamplePoint 现在依赖于 pytest。 这似乎是错误的。
如何扩展 pytest 以便approx()
与我的数据 Class 一起使用,而无需数据 Class 知道 Z5A748C120135ECAFE0627C?
您可以使用 math.isclose(),它还允许您设置您认为接近的容差。 在下面的代码中,我将它分别应用于数据类之外的 ExamplePoint 的 x 和 y 坐标,但您可以以不同的方式实现它:
@dataclass
class ExamplePoint:
x: float
y: float
p1 = ExamplePoint(1, 2)
p3 = ExamplePoint(1, math.sqrt(2) * math.sqrt(2))
print(math.isclose(p1.x, p3.x, rel_tol=0.01)) #True
print(math.isclose(p1.y, p3.y, rel_tol=0.01)) #True
更新:这是您如何将其合并到您的大约 function 中的方法:
from dataclasses import astuple, dataclass
import math
@dataclass
class ExamplePoint:
x: float
y: float
def approx(self, other):
return math.isclose(self.x,other.x, rel_tol=0.001) \
and math.isclose(self.y,other.y, rel_tol=0.001)
p1 = ExamplePoint(1, 2)
p2 = ExamplePoint(1,1.99)
p3 = ExamplePoint(1, math.sqrt(2) * math.sqrt(2))
print(p1.approx(p2)) #False
print(p1.approx(p3)) #True
Digging into the pytest code (see https://github.com/pytest-dev/pytest/blob/main/src/_pytest/python_api.py ), it appears that pytest checks whether the expected value is Iterable and Sizeable. 我可以使我的数据 Class Iterable 和 Sizeable 如下。
from dataclasses import astuple, dataclass
import pytest
import math
@dataclass
class ExamplePoint:
x: float
y: float
def approx(self, other):
return astuple(self) == pytest.approx(astuple(other))
def __iter__(self):
return iter(astuple(self))
def __len__(self):
return len(astuple(self))
p1 = ExamplePoint(1,2)
p3 = ExamplePoint(1, math.sqrt(2) * math.sqrt(2))
assert p1 == pytest.approx(p3) # True
您应该将 approx 添加到数据类的每个参数中:
from dataclasses import dataclass
import pytest
@dataclass
class Foo:
a: int
b: float
a = Foo(1, 3.00000001)
b = Foo(1, pytest.approx(3.0, abs=1e-3))
print(a == b)
见: https://github.com/pytest-dev/pytest/issues/6632#issuecomment-580487103
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.