[英]Python typing for context manager and string literal
How can I combine typing information in Python, using mypy, for a context manager combined with specific string literals in an overload?如何使用 mypy 将 Python 中的键入信息与重载中的特定字符串文字结合起来的上下文管理器?
I have a capture_output
function that redirects the standard output and/or standard error, which can be used as a context manager when running a unit test.我有一个
capture_output
function 重定向标准 output 和/或标准错误,可以在运行单元测试时用作上下文管理器。 It captures the stream(s), keeping them from printing to the screen, and then you can get a copy of any messages printing out during the test and compare their contents to what you expect.它捕获流,阻止它们打印到屏幕上,然后您可以获得在测试期间打印出的任何消息的副本,并将其内容与您期望的内容进行比较。 It remains one of the only functions that I flat out cannot figure out how to type correctly.
它仍然是我完全无法弄清楚如何正确输入的唯一功能之一。
Solutions for only Python v3.9+ are perfectly acceptable.仅适用于 Python v3.9+ 的解决方案是完全可以接受的。 I'm currently running Python v3.10.4 and mypy v0.961.
我目前正在运行 Python v3.10.4 和 mypy v0.961。
"""Typing test script to combine mypy and contextmanagers."""
#%% Imports
from contextlib import contextmanager
from io import StringIO
import sys
from typing import Iterator, Literal, overload, TextIO, Tuple, Union
import unittest
#%% Functions
# @overload
# def capture_output(mode: Literal["err"]) -> Iterator[TextIO]:
# ...
# @overload
# def capture_output(mode: Literal["all"]) -> Tuple[Iterator[TextIO], Iterator[TextIO]]:
# ...
# @overload
# def capture_output(mode: Literal["out"] = ...) -> Iterator[TextIO]:
# ...
@contextmanager
def capture_output(mode: Literal["out", "err", "all"] = "out") -> Union[Iterator[TextIO], Iterator[Tuple[TextIO, TextIO]]]:
r"""
Capture the stdout and stderr streams instead of displaying to the screen.
Parameters
----------
mode : str
Mode to use when capturing output
'out' captures just sys.stdout
'err' captures just sys.stderr
'all' captures both sys.stdout and sys.stderr
Returns
-------
out : class StringIO
stdout stream output
err : class StringIO
stderr stream output
Examples
--------
>>> with capture_output() as out:
... print('Hello, World!')
>>> output = out.getvalue().strip()
>>> out.close()
>>> print(output)
Hello, World!
"""
# alias modes
capture_out = mode in {"out", "all"}
capture_err = mode in {"err", "all"}
# create new string buffers
new_out, new_err = StringIO(), StringIO()
# alias the old string buffers for restoration afterwards
old_out, old_err = sys.stdout, sys.stderr
try:
# override the system buffers with the new ones
if capture_out:
sys.stdout = new_out
if capture_err:
sys.stderr = new_err
# yield results as desired
if mode == "out":
yield sys.stdout
elif mode == "err":
yield sys.stderr
elif mode == "all":
yield sys.stdout, sys.stderr
finally:
# restore the original buffers once all results are read
sys.stdout, sys.stderr = old_out, old_err
#%% Tests
class Test_capture_output(unittest.TestCase):
r"""
Tests the capture_output function with the following cases:
capture standard output
capture standard error
capture both standard output and standard error
unknown option
"""
def test_std_out(self) -> None:
with capture_output() as out:
print("Hello, World!")
output = out.getvalue().strip()
out.close()
self.assertEqual(output, "Hello, World!")
def test_std_err(self) -> None:
with capture_output("err") as err:
print("Error Raised.", file=sys.stderr)
error = err.getvalue().strip()
err.close()
self.assertEqual(error, "Error Raised.")
def test_all(self) -> None:
with capture_output("all") as (out, err):
print("Hello, World!")
print("Error Raised.", file=sys.stderr)
output = out.getvalue().strip()
error = err.getvalue().strip()
out.close()
err.close()
self.assertEqual(output, "Hello, World!")
self.assertEqual(error, "Error Raised.")
def test_bad_value(self) -> None:
with self.assertRaises(RuntimeError):
with capture_output("bad") as (out, err): # type: ignore[arg-type]
print("Lost values") # pragma: no cover
#%% Unit test execution
if __name__ == "__main__":
unittest.main(exit=False)
Running mypy on this file gives a bunch of statements about:在这个文件上运行 mypy 会给出一堆关于:
"Iterable[Iterable[str]]" has no attribute "getvalue" [attr-defined]
"Iterable[Iterable[str]]" has no attribute "close" [attr-defined]
Is there any other way to type the function to let it know about the two variations?有没有其他方法可以输入 function 让它知道这两种变化?
A few issues I noticed trying to apply typing to your implementation我注意到尝试将类型应用于您的实现的一些问题
capture_output
, which makes it impossible to know the return type at runtime (imagine you passed in a string variable instead of a string literal)capture_output
的返回类型,这使得在运行时无法知道返回类型(假设您传入的是字符串变量而不是字符串文字)stdout
/ stderr
can be TextIO
, which doesn't support getvalue()
stdout
/ stderr
可以是TextIO
,不支持getvalue()
As an intermediate Python user, my suggestion would be to provide support for cases where stdout
and stderr
are TextIO
, and standardise your capture_output
return type or encapsulate the return type as its own class as I've done below.作为中级 Python 用户,我的建议是为
stdout
和stderr
为TextIO
的情况提供支持,并将您的capture_output
返回类型标准化或将返回类型封装为自己的 class ,如下所示。 This passes mypy --strict
and your unit tests, but perhaps more advanced Python users would have a better solution.这通过
mypy --strict
和您的单元测试,但也许更高级的 Python 用户会有更好的解决方案。
"""Typing test script to combine mypy and contextmanagers."""
# %% Imports
from contextlib import contextmanager
from io import StringIO
import sys
from typing import Iterator, Literal, TextIO, Optional
import unittest
class CaptureOutputResult:
def __init__(self, stdout: Optional[StringIO | TextIO] = None, stderr: Optional[StringIO | TextIO] = None):
self.stdout = stdout
self.stderr = stderr
def close(self) -> None:
if self.stdout:
self.stdout.close()
if self.stderr:
self.stderr.close()
def get_output(self) -> str:
return CaptureOutputResult.get_std(self.stdout)
def get_error(self) -> str:
return CaptureOutputResult.get_std(self.stderr)
@staticmethod
def get_std(std: StringIO | TextIO | None) -> str:
if std is None:
return ""
if isinstance(std, StringIO):
return std.getvalue().strip()
if isinstance(std, TextIO):
# TODO: check this does what you expect
return "\n".join(std.readlines())
raise Exception(f"Unknown type {type(std)}")
@contextmanager
def capture_output(mode: Literal["out", "err", "all"] = "out") -> Iterator[CaptureOutputResult]:
r"""
Capture the stdout and stderr streams instead of displaying to the screen.
Parameters
----------
mode : str
Mode to use when capturing output
'out' captures just sys.stdout
'err' captures just sys.stderr
'all' captures both sys.stdout and sys.stderr
Returns
-------
out : class StringIO
stdout stream output
err : class StringIO
stderr stream output
Examples
--------
>>> with capture_output() as ctx:
... print('Hello, World!')
>>> output = ctx.get_output()
>>> ctx.close()
>>> print(output)
Hello, World!
"""
# alias modes
capture_out = mode in {"out", "all"}
capture_err = mode in {"err", "all"}
# create new string buffers
new_out, new_err = StringIO(), StringIO()
# alias the old string buffers for restoration afterwards
old_out, old_err = sys.stdout, sys.stderr
try:
# override the system buffers with the new ones
if capture_out:
sys.stdout = new_out
if capture_err:
sys.stderr = new_err
# yield results as desired
if mode == "out":
yield CaptureOutputResult(stdout=sys.stdout)
elif mode == "err":
yield CaptureOutputResult(stderr=sys.stderr)
elif mode == "all":
yield CaptureOutputResult(stdout=sys.stdout, stderr=sys.stderr)
finally:
# restore the original buffers once all results are read
sys.stdout, sys.stderr = old_out, old_err
# %% Tests
class Test_capture_output(unittest.TestCase):
r"""
Tests the capture_output function with the following cases:
capture standard output
capture standard error
capture both standard output and standard error
unknown option
"""
def test_std_out(self) -> None:
with capture_output() as ctx:
print("Hello, World!")
output = ctx.get_output()
ctx.close()
self.assertEqual(output, "Hello, World!")
def test_std_err(self) -> None:
with capture_output("err") as ctx:
print("Error Raised.", file=sys.stderr)
error = ctx.get_error()
ctx.close()
self.assertEqual(error, "Error Raised.")
def test_all(self) -> None:
with capture_output("all") as ctx:
print("Hello, World!")
print("Error Raised.", file=sys.stderr)
output = ctx.get_output()
error = ctx.get_error()
ctx.close()
self.assertEqual(output, "Hello, World!")
self.assertEqual(error, "Error Raised.")
def test_bad_value(self) -> None:
with self.assertRaises(RuntimeError):
with capture_output("bad") as ctx: # type: ignore[arg-type]
print("Lost values") # pragma: no cover
# %% Unit test execution
if __name__ == "__main__":
unittest.main(exit=False)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.