[英]Does pytest support the use of function factories in test files?
示例test.py
文件:
import torch
def one():
return torch.tensor(0.0132005215)
def two():
return torch.tensor(4.4345855713e-05)
def three():
return torch.tensor(7.1525573730e-07)
def test_method(method, expected_value):
value = method()
assert(torch.isclose(value, expected_value))
def test_one():
test_method(one, torch.tensor(0.0132005215))
def test_two():
test_method(two, torch.tensor(4.4345855713e-05))
def test_three():
test_method(three, torch.tensor(7.1525573730e-07))
# test_method(three, torch.tensor(1.0))
if __name__ == '__main__':
test_one()
test_two()
test_three()
基本上,我有几个要测试的函数(这里称为one
、 two
和three
),它们都具有相同的签名但内部结构不同。 因此,我没有编写函数test_one()
、 test_two()
等并因此复制代码,而是编写了一个“函数工厂”(这是正确的术语吗?) test_method
,它将 function 作为输入,预期结果并返回assert
命令的结果。
如您所见,现在测试是手动执行的:我运行脚本test.py
,查看屏幕,如果没有打印Assertion error
,我很高兴。 当然,我想通过使用pytest
来改进这一点,因为有人告诉我它是最简单和最常用的 Python 测试框架之一。 问题是,通过查看pytest
文档,我得到的印象是pytest
将尝试运行名称以test_
开头的所有函数。 当然,测试test_method
本身没有任何意义。 你能帮我重构这个测试脚本,以便我可以用pytest
运行它吗?
在 pytest 中,您可以使用测试参数化来实现这一点。 在您的情况下,您必须为测试提供不同的参数:
import pytest
@pytest.mark.parametrize("method, expected_value",
[(one, 0.0132005215),
(two, 4.4345855713e-05),
(three, 7.1525573730e-07)])
def test_method(method, expected_value):
value = method()
assert(torch.isclose(value, expected_value))
If you run python -m pytest -rA
(see the documentation for output options), you will get the output of three tests, something like:
======================================================= PASSES ========================================================
=============================================== short test summary info ===============================================
PASSED test.py::test_method[one-0.0132005215]
PASSED test.py::test_method[two-4.4345855713e-05]
PASSED test.py::test_method[three-7.152557373e-07]
================================================== 3 passed in 0.07s ==================================================
如果您不喜欢灯具名称,可以调整它们:
@pytest.mark.parametrize("method, expected_value",
[(one, 0.0132005215),
(two, 4.4345855713e-05),
(three, 7.1525573730e-07),
],
ids=["one", "two", "three"])
...
这给了你:
=============================================== short test summary info ===============================================
PASSED test.py::test_method[one]
PASSED test.py::test_method[two]
PASSED test.py::test_method[three]
================================================== 3 passed in 0.06s ==================================================
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.