简体   繁体   中英

Python typing for context manager and string literal

How can I combine typing information in Python, using mypy, for a context manager combined with specific string literals in an overload?

I have a capture_output function that redirects the standard output and/or standard error, which can be used as a context manager when running a unit test. It captures the stream(s), keeping them from printing to the screen, and then you can get a copy of any messages printing out during the test and compare their contents to what you expect. It remains one of the only functions that I flat out cannot figure out how to type correctly.

Solutions for only Python v3.9+ are perfectly acceptable. I'm currently running Python v3.10.4 and mypy v0.961.

"""Typing test script to combine mypy and contextmanagers."""

#%% Imports
from contextlib import contextmanager
from io import StringIO
import sys
from typing import Iterator, Literal, overload, TextIO, Tuple, Union
import unittest


#%% Functions
# @overload
# def capture_output(mode: Literal["err"]) -> Iterator[TextIO]:
#     ...


# @overload
# def capture_output(mode: Literal["all"]) -> Tuple[Iterator[TextIO], Iterator[TextIO]]:
#     ...


# @overload
# def capture_output(mode: Literal["out"] = ...) -> Iterator[TextIO]:
#     ...


@contextmanager
def capture_output(mode: Literal["out", "err", "all"] = "out") -> Union[Iterator[TextIO], Iterator[Tuple[TextIO, TextIO]]]:
    r"""
    Capture the stdout and stderr streams instead of displaying to the screen.

    Parameters
    ----------
    mode : str
        Mode to use when capturing output
            'out' captures just sys.stdout
            'err' captures just sys.stderr
            'all' captures both sys.stdout and sys.stderr

    Returns
    -------
    out : class StringIO
        stdout stream output
    err : class StringIO
        stderr stream output

    Examples
    --------
    >>> with capture_output() as out:
    ...     print('Hello, World!')
    >>> output = out.getvalue().strip()
    >>> out.close()
    >>> print(output)
    Hello, World!

    """
    # alias modes
    capture_out = mode in {"out", "all"}
    capture_err = mode in {"err", "all"}
    # create new string buffers
    new_out, new_err = StringIO(), StringIO()
    # alias the old string buffers for restoration afterwards
    old_out, old_err = sys.stdout, sys.stderr
    try:
        # override the system buffers with the new ones
        if capture_out:
            sys.stdout = new_out
        if capture_err:
            sys.stderr = new_err
        # yield results as desired
        if mode == "out":
            yield sys.stdout
        elif mode == "err":
            yield sys.stderr
        elif mode == "all":
            yield sys.stdout, sys.stderr
    finally:
        # restore the original buffers once all results are read
        sys.stdout, sys.stderr = old_out, old_err


#%% Tests
class Test_capture_output(unittest.TestCase):
    r"""
    Tests the capture_output function with the following cases:
        capture standard output
        capture standard error
        capture both standard output and standard error
        unknown option
    """

    def test_std_out(self) -> None:
        with capture_output() as out:
            print("Hello, World!")
        output = out.getvalue().strip()
        out.close()
        self.assertEqual(output, "Hello, World!")

    def test_std_err(self) -> None:
        with capture_output("err") as err:
            print("Error Raised.", file=sys.stderr)
        error = err.getvalue().strip()
        err.close()
        self.assertEqual(error, "Error Raised.")

    def test_all(self) -> None:
        with capture_output("all") as (out, err):
            print("Hello, World!")
            print("Error Raised.", file=sys.stderr)
        output = out.getvalue().strip()
        error = err.getvalue().strip()
        out.close()
        err.close()
        self.assertEqual(output, "Hello, World!")
        self.assertEqual(error, "Error Raised.")

    def test_bad_value(self) -> None:
        with self.assertRaises(RuntimeError):
            with capture_output("bad") as (out, err):  # type: ignore[arg-type]
                print("Lost values")  # pragma: no cover


#%% Unit test execution
if __name__ == "__main__":
    unittest.main(exit=False)

Running mypy on this file gives a bunch of statements about:

"Iterable[Iterable[str]]" has no attribute "getvalue"  [attr-defined]
"Iterable[Iterable[str]]" has no attribute "close"  [attr-defined]

Is there any other way to type the function to let it know about the two variations?

A few issues I noticed trying to apply typing to your implementation

  • A string determines the return type of capture_output , which makes it impossible to know the return type at runtime (imagine you passed in a string variable instead of a string literal)
  • stdout / stderr can be TextIO , which doesn't support getvalue()

As an intermediate Python user, my suggestion would be to provide support for cases where stdout and stderr are TextIO , and standardise your capture_output return type or encapsulate the return type as its own class as I've done below. This passes mypy --strict and your unit tests, but perhaps more advanced Python users would have a better solution.

"""Typing test script to combine mypy and contextmanagers."""

# %% Imports
from contextlib import contextmanager
from io import StringIO
import sys
from typing import Iterator, Literal, TextIO, Optional
import unittest


class CaptureOutputResult:

    def __init__(self, stdout: Optional[StringIO | TextIO] = None, stderr: Optional[StringIO | TextIO] = None):
        self.stdout = stdout
        self.stderr = stderr

    def close(self) -> None:
        if self.stdout:
            self.stdout.close()

        if self.stderr:
            self.stderr.close()

    def get_output(self) -> str:
        return CaptureOutputResult.get_std(self.stdout)

    def get_error(self) -> str:
        return CaptureOutputResult.get_std(self.stderr)

    @staticmethod
    def get_std(std: StringIO | TextIO | None) -> str:
        if std is None:
            return ""

        if isinstance(std, StringIO):
            return std.getvalue().strip()

        if isinstance(std, TextIO):
            # TODO: check this does what you expect
            return "\n".join(std.readlines())

        raise Exception(f"Unknown type {type(std)}")


@contextmanager
def capture_output(mode: Literal["out", "err", "all"] = "out") -> Iterator[CaptureOutputResult]:
    r"""
    Capture the stdout and stderr streams instead of displaying to the screen.

    Parameters
    ----------
    mode : str
        Mode to use when capturing output
            'out' captures just sys.stdout
            'err' captures just sys.stderr
            'all' captures both sys.stdout and sys.stderr

    Returns
    -------
    out : class StringIO
        stdout stream output
    err : class StringIO
        stderr stream output

    Examples
    --------
    >>> with capture_output() as ctx:
    ...     print('Hello, World!')
    >>> output = ctx.get_output()
    >>> ctx.close()
    >>> print(output)
    Hello, World!

    """
    # alias modes
    capture_out = mode in {"out", "all"}
    capture_err = mode in {"err", "all"}
    # create new string buffers
    new_out, new_err = StringIO(), StringIO()
    # alias the old string buffers for restoration afterwards
    old_out, old_err = sys.stdout, sys.stderr
    try:
        # override the system buffers with the new ones
        if capture_out:
            sys.stdout = new_out
        if capture_err:
            sys.stderr = new_err
        # yield results as desired
        if mode == "out":
            yield CaptureOutputResult(stdout=sys.stdout)
        elif mode == "err":
            yield CaptureOutputResult(stderr=sys.stderr)
        elif mode == "all":
            yield CaptureOutputResult(stdout=sys.stdout, stderr=sys.stderr)
    finally:
        # restore the original buffers once all results are read
        sys.stdout, sys.stderr = old_out, old_err


# %% Tests
class Test_capture_output(unittest.TestCase):
    r"""
    Tests the capture_output function with the following cases:
        capture standard output
        capture standard error
        capture both standard output and standard error
        unknown option
    """

    def test_std_out(self) -> None:
        with capture_output() as ctx:
            print("Hello, World!")
        output = ctx.get_output()
        ctx.close()
        self.assertEqual(output, "Hello, World!")

    def test_std_err(self) -> None:
        with capture_output("err") as ctx:
            print("Error Raised.", file=sys.stderr)
        error = ctx.get_error()
        ctx.close()
        self.assertEqual(error, "Error Raised.")

    def test_all(self) -> None:
        with capture_output("all") as ctx:
            print("Hello, World!")
            print("Error Raised.", file=sys.stderr)

        output = ctx.get_output()
        error = ctx.get_error()
        ctx.close()
        self.assertEqual(output, "Hello, World!")
        self.assertEqual(error, "Error Raised.")

    def test_bad_value(self) -> None:
        with self.assertRaises(RuntimeError):
            with capture_output("bad") as ctx:  # type: ignore[arg-type]
                print("Lost values")  # pragma: no cover


# %% Unit test execution
if __name__ == "__main__":
    unittest.main(exit=False)

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM