[英]assertAlmostEqual in Python unit-test for collections of floats
Python 單元測試框架中的assertAlmostEqual(x, y)方法測試x
和y
是否近似相等(假設它們是浮點數)。
assertAlmostEqual()
的問題在於它只適用於浮點數。 我正在尋找一種像assertAlmostEqual()
這樣的方法,它適用於浮點數列表、浮點數集、浮點數字典、浮點數元組、浮點數元組列表、浮點數列表集等。
例如,令x = 0.1234567890
, y = 0.1234567891
。 x
和y
幾乎相等,因為它們在除最后一位以外的每個數字上都一致。 因此self.assertAlmostEqual(x, y)
是True
因為assertAlmostEqual()
適用於浮點數。
我正在尋找更通用的assertAlmostEquals()
,它還評估以下對True
的調用:
self.assertAlmostEqual_generic([x, x, x], [y, y, y])
。self.assertAlmostEqual_generic({1: x, 2: x, 3: x}, {1: y, 2: y, 3: y})
。self.assertAlmostEqual_generic([(x,x)], [(y,y)])
。有這樣的方法還是我必須自己實現?
說明:
assertAlmostEquals()
有一個名為places
的可選參數,通過計算四舍五入到places
位數的差異來比較數字。 默認情況下places=7
,因此self.assertAlmostEqual(0.5, 0.4)
為假,而self.assertAlmostEqual(0.12345678, 0.12345679)
為真。 我推測的assertAlmostEqual_generic()
應該具有相同的功能。
如果兩個列表在完全相同的順序中具有幾乎相等的數字,則認為它們幾乎相等。 正式地, for i in range(n): self.assertAlmostEqual(list1[i], list2[i])
。
類似地,如果兩個集合可以轉換為幾乎相等的列表(通過為每個集合分配一個順序),則它們被認為是幾乎相等的。
類似地,如果每個字典的鍵集幾乎等於另一個字典的鍵集,則兩個字典被認為幾乎相等,並且對於每個這樣幾乎相等的鍵對,都有一個對應的幾乎相等的值。
一般來說:我認為兩個 collections 幾乎相等,如果它們相等,除了一些相應的浮點數幾乎相等。 換句話說,我想真正比較對象,但在沿途比較浮點數時精度較低(自定義)。
如果您不介意使用 NumPy(隨 Python(x,y) 一起提供),您可能需要查看np.testing
模塊,其中定義了一個assert_almost_equal
函數。
簽名是np.testing.assert_almost_equal(actual, desired, decimal=7, err_msg='', verbose=True)
>>> x = 1.000001
>>> y = 1.000002
>>> np.testing.assert_almost_equal(x, y)
AssertionError:
Arrays are not almost equal to 7 decimals
ACTUAL: 1.000001
DESIRED: 1.000002
>>> np.testing.assert_almost_equal(x, y, 5)
>>> np.testing.assert_almost_equal([x, x, x], [y, y, y], 5)
>>> np.testing.assert_almost_equal((x, x, x), (y, y, y), 5)
從 python 3.5 開始,您可以比較使用
math.isclose(a, b, rel_tol=1e-9, abs_tol=0.0)
如pep-0485 中所述。 實現應該等同於
abs(a-b) <= max( rel_tol * max(abs(a), abs(b)), abs_tol )
下面是我如何實現一個通用的is_almost_equal(first, second)
函數:
首先,復制您需要比較的對象( first
和second
),但不要制作精確的副本:刪除您在對象內遇到的任何浮點數的無意義的十進制數字。
既然你有first
和second
副本,其中無意義的十進制數字消失了,只需使用==
運算符比較first
和second
。
假設我們有一個cut_insignificant_digits_recursively(obj, places)
這復制功能obj
,但只保留了places
各浮在原有的最顯著十進制數字obj
。 這是is_almost_equals(first, second, places)
的工作實現:
from insignificant_digit_cutter import cut_insignificant_digits_recursively
def is_almost_equal(first, second, places):
'''returns True if first and second equal.
returns true if first and second aren't equal but have exactly the same
structure and values except for a bunch of floats which are just almost
equal (floats are almost equal if they're equal when we consider only the
[places] most significant digits of each).'''
if first == second: return True
cut_first = cut_insignificant_digits_recursively(first, places)
cut_second = cut_insignificant_digits_recursively(second, places)
return cut_first == cut_second
這是cut_insignificant_digits_recursively(obj, places)
的工作實現:
def cut_insignificant_digits(number, places):
'''cut the least significant decimal digits of a number,
leave only [places] decimal digits'''
if type(number) != float: return number
number_as_str = str(number)
end_of_number = number_as_str.find('.')+places+1
if end_of_number > len(number_as_str): return number
return float(number_as_str[:end_of_number])
def cut_insignificant_digits_lazy(iterable, places):
for obj in iterable:
yield cut_insignificant_digits_recursively(obj, places)
def cut_insignificant_digits_recursively(obj, places):
'''return a copy of obj except that every float loses its least significant
decimal digits remaining only [places] decimal digits'''
t = type(obj)
if t == float: return cut_insignificant_digits(obj, places)
if t in (list, tuple, set):
return t(cut_insignificant_digits_lazy(obj, places))
if t == dict:
return {cut_insignificant_digits_recursively(key, places):
cut_insignificant_digits_recursively(val, places)
for key,val in obj.items()}
return obj
代碼及其單元測試可在此處獲得: https : //github.com/snakile/approximate_comparator 。 我歡迎任何改進和錯誤修復。
如果你不介意使用numpy
包,那么numpy.testing
有assert_array_almost_equal
方法。
這適用於array_like
對象,因此它適用於array_like
的數組、列表和元組,但不適用於集合和字典。
文檔在這里。
沒有這樣的方法,你必須自己做。
對於列表和元組,定義是顯而易見的,但請注意,您提到的其他情況並不明顯,因此不提供此類功能也就不足為奇了。 例如, {1.00001: 1.00002}
幾乎等於{1.00002: 1.00001}
嗎? 處理此類情況需要選擇是否接近取決於鍵或值或兩者。 對於集合,您不太可能找到有意義的定義,因為集合是無序的,因此沒有“對應”元素的概念。
您可能必須自己實現它,雖然確實可以以相同的方式迭代列表和集合,但字典是另一回事,您迭代它們的鍵而不是值,第三個示例對我來說似乎有點模棱兩可,您的意思是比較集合中的每個值,或每個集合中的每個值。
這是一個簡單的代碼片段。
def almost_equal(value_1, value_2, accuracy = 10**-8):
return abs(value_1 - value_2) < accuracy
x = [1,2,3,4]
y = [1,2,4,5]
assert all(almost_equal(*values) for values in zip(x, y))
我自己看了一下,我將 UnitTest 庫的addTypeEqualityFunc方法與math.isclose
結合使用。
示例設置:
import math
from unittest import TestCase
class SomeFixtures(TestCase):
@classmethod
def float_comparer(cls, a, b, msg=None):
if len(a) != len(b):
raise cls.failureException(msg)
if not all(map(lambda args: math.isclose(*args), zip(a, b))):
raise cls.failureException(msg)
def some_test(self):
self.addTypeEqualityFunc(list, self.float_comparer)
self.assertEqual([1.0, 2.0, 3.0], [1.0, 2.0, 3.0])
這些答案都不適合我。 以下代碼適用於 Python 集合、類、數據類和命名元組。 我可能忘記了一些東西,但到目前為止這對我有用。
import unittest
from collections import namedtuple, OrderedDict
from dataclasses import dataclass
from typing import Any
def are_almost_equal(o1: Any, o2: Any, max_abs_ratio_diff: float, max_abs_diff: float) -> bool:
"""
Compares two objects by recursively walking them trough. Equality is as usual except for floats.
Floats are compared according to the two measures defined below.
:param o1: The first object.
:param o2: The second object.
:param max_abs_ratio_diff: The maximum allowed absolute value of the difference.
`abs(1 - (o1 / o2)` and vice-versa if o2 == 0.0. Ignored if < 0.
:param max_abs_diff: The maximum allowed absolute difference `abs(o1 - o2)`. Ignored if < 0.
:return: Whether the two objects are almost equal.
"""
if type(o1) != type(o2):
return False
composite_type_passed = False
if hasattr(o1, '__slots__'):
if len(o1.__slots__) != len(o2.__slots__):
return False
if any(not are_almost_equal(getattr(o1, s1), getattr(o2, s2),
max_abs_ratio_diff, max_abs_diff)
for s1, s2 in zip(sorted(o1.__slots__), sorted(o2.__slots__))):
return False
else:
composite_type_passed = True
if hasattr(o1, '__dict__'):
if len(o1.__dict__) != len(o2.__dict__):
return False
if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
for ((k1, v1), (k2, v2))
in zip(sorted(o1.__dict__.items()), sorted(o2.__dict__.items()))
if not k1.startswith('__')): # avoid infinite loops
return False
else:
composite_type_passed = True
if isinstance(o1, dict):
if len(o1) != len(o2):
return False
if any(not are_almost_equal(k1, k2, max_abs_ratio_diff, max_abs_diff)
or not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
for ((k1, v1), (k2, v2)) in zip(sorted(o1.items()), sorted(o2.items()))):
return False
elif any(issubclass(o1.__class__, c) for c in (list, tuple, set)):
if len(o1) != len(o2):
return False
if any(not are_almost_equal(v1, v2, max_abs_ratio_diff, max_abs_diff)
for v1, v2 in zip(o1, o2)):
return False
elif isinstance(o1, float):
if o1 == o2:
return True
else:
if max_abs_ratio_diff > 0: # if max_abs_ratio_diff < 0, max_abs_ratio_diff is ignored
if o2 != 0:
if abs(1.0 - (o1 / o2)) > max_abs_ratio_diff:
return False
else: # if both == 0, we already returned True
if abs(1.0 - (o2 / o1)) > max_abs_ratio_diff:
return False
if 0 < max_abs_diff < abs(o1 - o2): # if max_abs_diff < 0, max_abs_diff is ignored
return False
return True
else:
if not composite_type_passed:
return o1 == o2
return True
class EqualityTest(unittest.TestCase):
def test_floats(self) -> None:
o1 = ('hi', 3, 3.4)
o2 = ('hi', 3, 3.400001)
self.assertTrue(are_almost_equal(o1, o2, 0.0001, 0.0001))
self.assertFalse(are_almost_equal(o1, o2, 0.00000001, 0.00000001))
def test_ratio_only(self):
o1 = ['hey', 10000, 123.12]
o2 = ['hey', 10000, 123.80]
self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
self.assertFalse(are_almost_equal(o1, o2, 0.001, -1))
def test_diff_only(self):
o1 = ['hey', 10000, 1234567890.12]
o2 = ['hey', 10000, 1234567890.80]
self.assertTrue(are_almost_equal(o1, o2, -1, 1))
self.assertFalse(are_almost_equal(o1, o2, -1, 0.1))
def test_both_ignored(self):
o1 = ['hey', 10000, 1234567890.12]
o2 = ['hey', 10000, 0.80]
o3 = ['hi', 10000, 0.80]
self.assertTrue(are_almost_equal(o1, o2, -1, -1))
self.assertFalse(are_almost_equal(o1, o3, -1, -1))
def test_different_lengths(self):
o1 = ['hey', 1234567890.12, 10000]
o2 = ['hey', 1234567890.80]
self.assertFalse(are_almost_equal(o1, o2, 1, 1))
def test_classes(self):
class A:
d = 12.3
def __init__(self, a, b, c):
self.a = a
self.b = b
self.c = c
o1 = A(2.34, 'str', {1: 'hey', 345.23: [123, 'hi', 890.12]})
o2 = A(2.34, 'str', {1: 'hey', 345.231: [123, 'hi', 890.121]})
self.assertTrue(are_almost_equal(o1, o2, 0.1, 0.1))
self.assertFalse(are_almost_equal(o1, o2, 0.0001, 0.0001))
o2.hello = 'hello'
self.assertFalse(are_almost_equal(o1, o2, -1, -1))
def test_namedtuples(self):
B = namedtuple('B', ['x', 'y'])
o1 = B(3.3, 4.4)
o2 = B(3.4, 4.5)
self.assertTrue(are_almost_equal(o1, o2, 0.2, 0.2))
self.assertFalse(are_almost_equal(o1, o2, 0.001, 0.001))
def test_classes_with_slots(self):
class C(object):
__slots__ = ['a', 'b']
def __init__(self, a, b):
self.a = a
self.b = b
o1 = C(3.3, 4.4)
o2 = C(3.4, 4.5)
self.assertTrue(are_almost_equal(o1, o2, 0.3, 0.3))
self.assertFalse(are_almost_equal(o1, o2, -1, 0.01))
def test_dataclasses(self):
@dataclass
class D:
s: str
i: int
f: float
@dataclass
class E:
f2: float
f4: str
d: D
o1 = E(12.3, 'hi', D('hello', 34, 20.01))
o2 = E(12.1, 'hi', D('hello', 34, 20.0))
self.assertTrue(are_almost_equal(o1, o2, -1, 0.4))
self.assertFalse(are_almost_equal(o1, o2, -1, 0.001))
o3 = E(12.1, 'hi', D('ciao', 34, 20.0))
self.assertFalse(are_almost_equal(o2, o3, -1, -1))
def test_ordereddict(self):
o1 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.12]})
o2 = OrderedDict({1: 'hey', 345.23: [123, 'hi', 890.0]})
self.assertTrue(are_almost_equal(o1, o2, 0.01, -1))
self.assertFalse(are_almost_equal(o1, o2, 0.0001, -1))
我仍然會使用self.assertEqual()
因為它會在粉絲遇到self.assertEqual()
時保持最有用的信息。 你可以通過舍入來做到這一點,例如。
self.assertEqual(round_tuple((13.949999999999999, 1.121212), 2), (13.95, 1.12))
round_tuple
在哪里
def round_tuple(t: tuple, ndigits: int) -> tuple:
return tuple(round(e, ndigits=ndigits) for e in t)
def round_list(l: list, ndigits: int) -> list:
return [round(e, ndigits=ndigits) for e in l]
根據 python 文檔(參見https://stackoverflow.com/a/41407651/1031191 ),您可以避免舍入問題,例如 13.94999999,因為13.94999999 == 13.95
是True
。
另一種方法是將兩個字典等中的每一個轉換為熊貓數據幀,然后使用pd.testing.assert_frame_equal()
來比較兩者。 我已經成功地使用它來比較字典列表。
以前的答案通常不適用於涉及字典的結構,但這個答案應該。 我還沒有在高度嵌套的結構上對此進行詳盡的測試,但想象一下熊貓會正確處理它們。
為了說明這一點,我將使用您的 dict 示例數據,因為其他方法不適用於 dicts。 你的字典是:
x, y = 0.1234567890, 0.1234567891
{1: x, 2: x, 3: x}, {1: y, 2: y, 3: y}
然后我們可以這樣做:
pd.testing.assert_frame_equal(
pd.DataFrame.from_dict({1: x, 2: x, 3: x}, orient='index') ,
pd.DataFrame.from_dict({1: y, 2: y, 3: y}, orient='index') )
這不會引起錯誤,這意味着它們等於一定程度的精度。
但是,如果我們要這樣做
pd.testing.assert_frame_equal(
pd.DataFrame.from_dict({1: x, 2: x, 3: x}, orient='index') ,
pd.DataFrame.from_dict({1: y, 2: y, 3: y + 1}, orient='index') ) #add 1 to last value
然后我們將收到以下信息性消息:
AssertionError: DataFrame.iloc[:, 0] (column name="0") are different
DataFrame.iloc[:, 0] (column name="0") values are different (33.33333 %)
[index]: [1, 2, 3]
[left]: [0.123456789, 0.123456789, 0.123456789]
[right]: [0.1234567891, 0.1234567891, 1.1234567891]
有關更多詳細信息,請參閱pd.testing.assert_frame_equal 文檔,特別是參數check_exact
、 rtol
、 atol
以獲取有關如何指定所需的相對精度或實際精度的信息。
a = {i*10 : {1:1.1,2:2.1} for i in range(4)}
b = {i*10 : {1:1.1000001,2:2.100001} for i in range(4)}
# a = {0: {1: 1.1, 2: 2.1}, 10: {1: 1.1, 2: 2.1}, 20: {1: 1.1, 2: 2.1}, 30: {1: 1.1, 2: 2.1}}
# b = {0: {1: 1.1000001, 2: 2.100001}, 10: {1: 1.1000001, 2: 2.100001}, 20: {1: 1.1000001, 2: 2.100001}, 30: {1: 1.1000001, 2: 2.100001}}
然后做
pd.testing.assert_frame_equal( pd.DataFrame(a), pd.DataFrame(b) )
- 它不會引發錯誤:所有值都相當相似。 但是,如果我們改變一個值,例如
b[30][2] += 1
# b = {0: {1: 1.1000001, 2: 2.1000001}, 10: {1: 1.1000001, 2: 2.1000001}, 20: {1: 1.1000001, 2: 2.1000001}, 30: {1: 1.1000001, 2: 3.1000001}}
然后運行相同的測試,我們得到以下明確的錯誤消息:
AssertionError: DataFrame.iloc[:, 3] (column name="30") are different
DataFrame.iloc[:, 3] (column name="30") values are different (50.0 %)
[index]: [1, 2]
[left]: [1.1, 2.1]
[right]: [1.1000001, 3.1000001]
您還可以遞歸調用已經存在的unittest.assertAlmostEquals()
並通過向您的單元測試添加一個方法來跟蹤您正在比較的元素。
例如,對於列表列表和浮點數元組列表:
def assertListAlmostEqual(self, first, second, delta=None, context=None):
"""Asserts lists of lists or tuples to check if they compare and
shows which element is wrong when comparing two lists
"""
self.assertEqual(len(first), len(second), msg="List have different length")
context = [first, second] if context is None else context
for i in range(0, len(first)):
if isinstance(first[0], tuple):
context.append(i)
self.assertListAlmostEqual(first[i], second[i], delta, context=context)
if isinstance(first[0], list):
context.append(i)
self.assertListAlmostEqual(first[i], second[i], delta, context=context)
elif isinstance(first[0], float):
msg = "Difference in \n{} and \n{}\nFaulty element index={}".format(context[0], context[1], context[2:]+[i]) \
if context is not None else None
self.assertAlmostEqual(first[i], second[i], delta, msg=msg)
輸出類似:
line 23, in assertListAlmostEqual
self.assertAlmostEqual(first[i], second[i], delta, msg=msg)
AssertionError: 5.0 != 6.0 within 7 places (1.0 difference) : Difference in
[(0.0, 5.0), (8.0, 2.0), (10.0, 1.999999), (11.0, 1.9999989090909092)] and
[(0.0, 6.0), (8.0, 2.0), (10.0, 1.999999), (11.0, 1.9999989)]
Faulty element index=[0, 1]
另一種方法是將您的數據轉換為可比較的形式,例如將每個浮點數轉換為具有固定精度的字符串。
def comparable(data):
"""Converts `data` to a comparable structure by converting any floats to a string with fixed precision."""
if isinstance(data, (int, str)):
return data
if isinstance(data, float):
return '{:.4f}'.format(data)
if isinstance(data, list):
return [comparable(el) for el in data]
if isinstance(data, tuple):
return tuple([comparable(el) for el in data])
if isinstance(data, dict):
return {k: comparable(v) for k, v in data.items()}
那么你也能:
self.assertEquals(comparable(value1), comparable(value2))
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.