How can I combine typing information in Python, using mypy, for a context manager combined with specific string literals in an overload?
I have a capture_output
function that redirects the standard output and/or standard error, which can be used as a context manager when running a unit test. It captures the stream(s), keeping them from printing to the screen, and then you can get a copy of any messages printing out during the test and compare their contents to what you expect. It remains one of the only functions that I flat out cannot figure out how to type correctly.
Solutions for only Python v3.9+ are perfectly acceptable. I'm currently running Python v3.10.4 and mypy v0.961.
"""Typing test script to combine mypy and contextmanagers."""
#%% Imports
from contextlib import contextmanager
from io import StringIO
import sys
from typing import Iterator, Literal, overload, TextIO, Tuple, Union
import unittest
#%% Functions
# @overload
# def capture_output(mode: Literal["err"]) -> Iterator[TextIO]:
# ...
# @overload
# def capture_output(mode: Literal["all"]) -> Tuple[Iterator[TextIO], Iterator[TextIO]]:
# ...
# @overload
# def capture_output(mode: Literal["out"] = ...) -> Iterator[TextIO]:
# ...
@contextmanager
def capture_output(mode: Literal["out", "err", "all"] = "out") -> Union[Iterator[TextIO], Iterator[Tuple[TextIO, TextIO]]]:
r"""
Capture the stdout and stderr streams instead of displaying to the screen.
Parameters
----------
mode : str
Mode to use when capturing output
'out' captures just sys.stdout
'err' captures just sys.stderr
'all' captures both sys.stdout and sys.stderr
Returns
-------
out : class StringIO
stdout stream output
err : class StringIO
stderr stream output
Examples
--------
>>> with capture_output() as out:
... print('Hello, World!')
>>> output = out.getvalue().strip()
>>> out.close()
>>> print(output)
Hello, World!
"""
# alias modes
capture_out = mode in {"out", "all"}
capture_err = mode in {"err", "all"}
# create new string buffers
new_out, new_err = StringIO(), StringIO()
# alias the old string buffers for restoration afterwards
old_out, old_err = sys.stdout, sys.stderr
try:
# override the system buffers with the new ones
if capture_out:
sys.stdout = new_out
if capture_err:
sys.stderr = new_err
# yield results as desired
if mode == "out":
yield sys.stdout
elif mode == "err":
yield sys.stderr
elif mode == "all":
yield sys.stdout, sys.stderr
finally:
# restore the original buffers once all results are read
sys.stdout, sys.stderr = old_out, old_err
#%% Tests
class Test_capture_output(unittest.TestCase):
r"""
Tests the capture_output function with the following cases:
capture standard output
capture standard error
capture both standard output and standard error
unknown option
"""
def test_std_out(self) -> None:
with capture_output() as out:
print("Hello, World!")
output = out.getvalue().strip()
out.close()
self.assertEqual(output, "Hello, World!")
def test_std_err(self) -> None:
with capture_output("err") as err:
print("Error Raised.", file=sys.stderr)
error = err.getvalue().strip()
err.close()
self.assertEqual(error, "Error Raised.")
def test_all(self) -> None:
with capture_output("all") as (out, err):
print("Hello, World!")
print("Error Raised.", file=sys.stderr)
output = out.getvalue().strip()
error = err.getvalue().strip()
out.close()
err.close()
self.assertEqual(output, "Hello, World!")
self.assertEqual(error, "Error Raised.")
def test_bad_value(self) -> None:
with self.assertRaises(RuntimeError):
with capture_output("bad") as (out, err): # type: ignore[arg-type]
print("Lost values") # pragma: no cover
#%% Unit test execution
if __name__ == "__main__":
unittest.main(exit=False)
Running mypy on this file gives a bunch of statements about:
"Iterable[Iterable[str]]" has no attribute "getvalue" [attr-defined]
"Iterable[Iterable[str]]" has no attribute "close" [attr-defined]
Is there any other way to type the function to let it know about the two variations?
A few issues I noticed trying to apply typing to your implementation
capture_output
, which makes it impossible to know the return type at runtime (imagine you passed in a string variable instead of a string literal)stdout
/ stderr
can be TextIO
, which doesn't support getvalue()
As an intermediate Python user, my suggestion would be to provide support for cases where stdout
and stderr
are TextIO
, and standardise your capture_output
return type or encapsulate the return type as its own class as I've done below. This passes mypy --strict
and your unit tests, but perhaps more advanced Python users would have a better solution.
"""Typing test script to combine mypy and contextmanagers."""
# %% Imports
from contextlib import contextmanager
from io import StringIO
import sys
from typing import Iterator, Literal, TextIO, Optional
import unittest
class CaptureOutputResult:
def __init__(self, stdout: Optional[StringIO | TextIO] = None, stderr: Optional[StringIO | TextIO] = None):
self.stdout = stdout
self.stderr = stderr
def close(self) -> None:
if self.stdout:
self.stdout.close()
if self.stderr:
self.stderr.close()
def get_output(self) -> str:
return CaptureOutputResult.get_std(self.stdout)
def get_error(self) -> str:
return CaptureOutputResult.get_std(self.stderr)
@staticmethod
def get_std(std: StringIO | TextIO | None) -> str:
if std is None:
return ""
if isinstance(std, StringIO):
return std.getvalue().strip()
if isinstance(std, TextIO):
# TODO: check this does what you expect
return "\n".join(std.readlines())
raise Exception(f"Unknown type {type(std)}")
@contextmanager
def capture_output(mode: Literal["out", "err", "all"] = "out") -> Iterator[CaptureOutputResult]:
r"""
Capture the stdout and stderr streams instead of displaying to the screen.
Parameters
----------
mode : str
Mode to use when capturing output
'out' captures just sys.stdout
'err' captures just sys.stderr
'all' captures both sys.stdout and sys.stderr
Returns
-------
out : class StringIO
stdout stream output
err : class StringIO
stderr stream output
Examples
--------
>>> with capture_output() as ctx:
... print('Hello, World!')
>>> output = ctx.get_output()
>>> ctx.close()
>>> print(output)
Hello, World!
"""
# alias modes
capture_out = mode in {"out", "all"}
capture_err = mode in {"err", "all"}
# create new string buffers
new_out, new_err = StringIO(), StringIO()
# alias the old string buffers for restoration afterwards
old_out, old_err = sys.stdout, sys.stderr
try:
# override the system buffers with the new ones
if capture_out:
sys.stdout = new_out
if capture_err:
sys.stderr = new_err
# yield results as desired
if mode == "out":
yield CaptureOutputResult(stdout=sys.stdout)
elif mode == "err":
yield CaptureOutputResult(stderr=sys.stderr)
elif mode == "all":
yield CaptureOutputResult(stdout=sys.stdout, stderr=sys.stderr)
finally:
# restore the original buffers once all results are read
sys.stdout, sys.stderr = old_out, old_err
# %% Tests
class Test_capture_output(unittest.TestCase):
r"""
Tests the capture_output function with the following cases:
capture standard output
capture standard error
capture both standard output and standard error
unknown option
"""
def test_std_out(self) -> None:
with capture_output() as ctx:
print("Hello, World!")
output = ctx.get_output()
ctx.close()
self.assertEqual(output, "Hello, World!")
def test_std_err(self) -> None:
with capture_output("err") as ctx:
print("Error Raised.", file=sys.stderr)
error = ctx.get_error()
ctx.close()
self.assertEqual(error, "Error Raised.")
def test_all(self) -> None:
with capture_output("all") as ctx:
print("Hello, World!")
print("Error Raised.", file=sys.stderr)
output = ctx.get_output()
error = ctx.get_error()
ctx.close()
self.assertEqual(output, "Hello, World!")
self.assertEqual(error, "Error Raised.")
def test_bad_value(self) -> None:
with self.assertRaises(RuntimeError):
with capture_output("bad") as ctx: # type: ignore[arg-type]
print("Lost values") # pragma: no cover
# %% Unit test execution
if __name__ == "__main__":
unittest.main(exit=False)
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.