Python typing for context manager and string literal

How can I combine Python type hints, checked with mypy, for a context manager whose yielded value depends on specific string literals passed through an overload?

I have a capture_output function that redirects standard output and/or standard error and can be used as a context manager when running a unit test. It captures the stream(s), keeping them from printing to the screen, so that you can get a copy of any messages printed during the test and compare their contents to what you expect. It remains one of the few functions that I flat out cannot figure out how to type correctly.

Solutions that only work on Python v3.9+ are perfectly acceptable. I'm currently running Python v3.10.4 and mypy v0.961.

"""Typing test script to combine mypy and contextmanagers."""

#%% Imports
from contextlib import contextmanager
from io import StringIO
import sys
from typing import Iterator, Literal, overload, TextIO, Tuple, Union
import unittest


#%% Functions
# @overload
# def capture_output(mode: Literal["err"]) -> Iterator[TextIO]:
#     ...


# @overload
# def capture_output(mode: Literal["all"]) -> Tuple[Iterator[TextIO], Iterator[TextIO]]:
#     ...


# @overload
# def capture_output(mode: Literal["out"] = ...) -> Iterator[TextIO]:
#     ...


@contextmanager
def capture_output(mode: Literal["out", "err", "all"] = "out") -> Union[Iterator[TextIO], Iterator[Tuple[TextIO, TextIO]]]:
    r"""
    Capture the stdout and stderr streams instead of displaying to the screen.

    Parameters
    ----------
    mode : str
        Mode to use when capturing output
            'out' captures just sys.stdout
            'err' captures just sys.stderr
            'all' captures both sys.stdout and sys.stderr

    Returns
    -------
    out : class StringIO
        stdout stream output
    err : class StringIO
        stderr stream output

    Examples
    --------
    >>> with capture_output() as out:
    ...     print('Hello, World!')
    >>> output = out.getvalue().strip()
    >>> out.close()
    >>> print(output)
    Hello, World!

    """
    # alias modes
    capture_out = mode in {"out", "all"}
    capture_err = mode in {"err", "all"}
    # create new string buffers
    new_out, new_err = StringIO(), StringIO()
    # alias the old string buffers for restoration afterwards
    old_out, old_err = sys.stdout, sys.stderr
    try:
        # override the system buffers with the new ones
        if capture_out:
            sys.stdout = new_out
        if capture_err:
            sys.stderr = new_err
        # yield results as desired
        if mode == "out":
            yield sys.stdout
        elif mode == "err":
            yield sys.stderr
        elif mode == "all":
            yield sys.stdout, sys.stderr
    finally:
        # restore the original buffers once all results are read
        sys.stdout, sys.stderr = old_out, old_err


#%% Tests
class Test_capture_output(unittest.TestCase):
    r"""
    Tests the capture_output function with the following cases:
        capture standard output
        capture standard error
        capture both standard output and standard error
        unknown option
    """

    def test_std_out(self) -> None:
        with capture_output() as out:
            print("Hello, World!")
        output = out.getvalue().strip()
        out.close()
        self.assertEqual(output, "Hello, World!")

    def test_std_err(self) -> None:
        with capture_output("err") as err:
            print("Error Raised.", file=sys.stderr)
        error = err.getvalue().strip()
        err.close()
        self.assertEqual(error, "Error Raised.")

    def test_all(self) -> None:
        with capture_output("all") as (out, err):
            print("Hello, World!")
            print("Error Raised.", file=sys.stderr)
        output = out.getvalue().strip()
        error = err.getvalue().strip()
        out.close()
        err.close()
        self.assertEqual(output, "Hello, World!")
        self.assertEqual(error, "Error Raised.")

    def test_bad_value(self) -> None:
        with self.assertRaises(RuntimeError):
            with capture_output("bad") as (out, err):  # type: ignore[arg-type]
                print("Lost values")  # pragma: no cover


#%% Unit test execution
if __name__ == "__main__":
    unittest.main(exit=False)

Running mypy on this file gives a bunch of errors like:

"Iterable[Iterable[str]]" has no attribute "getvalue"  [attr-defined]
"Iterable[Iterable[str]]" has no attribute "close"  [attr-defined]

Is there any other way to type the function so that mypy knows about the two variations?
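The commented-out overloads in the question annotate the undecorated generator (Iterator[...]), but callers only ever see the object produced by @contextmanager, so overload signatures have to describe context managers instead. Below is a minimal sketch of that idea, assuming Python 3.9+ and assuming it is acceptable to yield the StringIO buffers directly rather than sys.stdout / sys.stderr (which typeshed declares as TextIO). It is not the original author's code, and the mypy behaviour described in the comments should be verified against your mypy version.

```python
import sys
from contextlib import contextmanager
from io import StringIO
from typing import Any, ContextManager, Iterator, Literal, Tuple, overload


# The overloads describe what callers see: the context manager returned
# by the @contextmanager-decorated function, not the raw generator.
@overload
def capture_output(mode: Literal["out", "err"] = ...) -> ContextManager[StringIO]: ...
@overload
def capture_output(mode: Literal["all"]) -> ContextManager[Tuple[StringIO, StringIO]]: ...


@contextmanager
def capture_output(mode: str = "out") -> Iterator[Any]:
    # Iterator[Any] keeps the decorated implementation compatible with every overload.
    new_out, new_err = StringIO(), StringIO()
    old_out, old_err = sys.stdout, sys.stderr
    try:
        if mode in {"out", "all"}:
            sys.stdout = new_out
        if mode in {"err", "all"}:
            sys.stderr = new_err
        # Yield the StringIO buffers themselves so callers can use getvalue().
        if mode == "out":
            yield new_out
        elif mode == "err":
            yield new_err
        elif mode == "all":
            yield new_out, new_err
        # Any other mode yields nothing, so contextmanager raises RuntimeError.
    finally:
        # Always restore the original streams.
        sys.stdout, sys.stderr = old_out, old_err
```

With literal arguments, capture_output() and capture_output("err") should resolve to the single-stream overload and capture_output("all") to the tuple overload, so out.getvalue() type-checks in each test. Passing a plain str variable matches neither overload, which is the limitation raised in the answer below.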

A few issues I noticed when trying to apply typing to your implementation:

  • The mode string determines the return type of capture_output, so the return type can only be known statically when a string literal is passed (imagine you passed in a string variable instead of a string literal)
  • sys.stdout / sys.stderr are typed as TextIO (and at runtime can be any text stream), and TextIO doesn't support getvalue()
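The second point can be sidestepped by keeping a reference to the StringIO object you created instead of reading it back through sys.stdout. As one illustration, swapping in the standard library's contextlib.redirect_stdout and contextlib.redirect_stderr in place of the hand-written swap from the question: the object bound by "as" is the buffer you passed in, so mypy should infer StringIO rather than TextIO for out and err.

```python
import sys
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO

# Each redirect returns the target it was given, so `out` and `err`
# are StringIO objects (statically and at runtime), not bare TextIO.
with redirect_stdout(StringIO()) as out, redirect_stderr(StringIO()) as err:
    print("Hello, World!")
    print("Error Raised.", file=sys.stderr)

# getvalue() is available because we read from our own buffers,
# not from sys.stdout / sys.stderr (declared as TextIO in typeshed).
output = out.getvalue().strip()
error = err.getvalue().strip()
```

Both streams are restored automatically when the with block exits, including when an exception is raised inside it.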

As an intermediate Python user, my suggestion would be to support the case where stdout and stderr are TextIO, and to either standardise the return type of capture_output or encapsulate it in its own class, as I've done below. This passes mypy --strict and your unit tests, but more advanced Python users may have a better solution.

"""Typing test script to combine mypy and contextmanagers."""

# %% Imports
from contextlib import contextmanager
from io import StringIO
import sys
from typing import Iterator, Literal, TextIO, Optional
import unittest


class CaptureOutputResult:

    def __init__(self, stdout: Optional[StringIO | TextIO] = None, stderr: Optional[StringIO | TextIO] = None):
        self.stdout = stdout
        self.stderr = stderr

    def close(self) -> None:
        if self.stdout:
            self.stdout.close()

        if self.stderr:
            self.stderr.close()

    def get_output(self) -> str:
        return CaptureOutputResult.get_std(self.stdout)

    def get_error(self) -> str:
        return CaptureOutputResult.get_std(self.stderr)

    @staticmethod
    def get_std(std: StringIO | TextIO | None) -> str:
        if std is None:
            return ""

        if isinstance(std, StringIO):
            return std.getvalue().strip()

        if isinstance(std, TextIO):
            # TODO: check this does what you expect; readlines() keeps the
            # trailing newlines, so join with "" rather than "\n"
            return "".join(std.readlines()).strip()

        raise Exception(f"Unknown type {type(std)}")


@contextmanager
def capture_output(mode: Literal["out", "err", "all"] = "out") -> Iterator[CaptureOutputResult]:
    r"""
    Capture the stdout and stderr streams instead of displaying to the screen.

    Parameters
    ----------
    mode : str
        Mode to use when capturing output
            'out' captures just sys.stdout
            'err' captures just sys.stderr
            'all' captures both sys.stdout and sys.stderr

    Yields
    ------
    ctx : CaptureOutputResult
        Wrapper around the captured stdout and/or stderr streams

    Examples
    --------
    >>> with capture_output() as ctx:
    ...     print('Hello, World!')
    >>> output = ctx.get_output()
    >>> ctx.close()
    >>> print(output)
    Hello, World!

    """
    # alias modes
    capture_out = mode in {"out", "all"}
    capture_err = mode in {"err", "all"}
    # create new string buffers
    new_out, new_err = StringIO(), StringIO()
    # alias the old string buffers for restoration afterwards
    old_out, old_err = sys.stdout, sys.stderr
    try:
        # override the system buffers with the new ones
        if capture_out:
            sys.stdout = new_out
        if capture_err:
            sys.stderr = new_err
        # yield results as desired
        if mode == "out":
            yield CaptureOutputResult(stdout=sys.stdout)
        elif mode == "err":
            yield CaptureOutputResult(stderr=sys.stderr)
        elif mode == "all":
            yield CaptureOutputResult(stdout=sys.stdout, stderr=sys.stderr)
    finally:
        # restore the original buffers once all results are read
        sys.stdout, sys.stderr = old_out, old_err


# %% Tests
class Test_capture_output(unittest.TestCase):
    r"""
    Tests the capture_output function with the following cases:
        capture standard output
        capture standard error
        capture both standard output and standard error
        unknown option
    """

    def test_std_out(self) -> None:
        with capture_output() as ctx:
            print("Hello, World!")
        output = ctx.get_output()
        ctx.close()
        self.assertEqual(output, "Hello, World!")

    def test_std_err(self) -> None:
        with capture_output("err") as ctx:
            print("Error Raised.", file=sys.stderr)
        error = ctx.get_error()
        ctx.close()
        self.assertEqual(error, "Error Raised.")

    def test_all(self) -> None:
        with capture_output("all") as ctx:
            print("Hello, World!")
            print("Error Raised.", file=sys.stderr)

        output = ctx.get_output()
        error = ctx.get_error()
        ctx.close()
        self.assertEqual(output, "Hello, World!")
        self.assertEqual(error, "Error Raised.")

    def test_bad_value(self) -> None:
        with self.assertRaises(RuntimeError):
            with capture_output("bad") as ctx:  # type: ignore[arg-type]
                print("Lost values")  # pragma: no cover


# %% Unit test execution
if __name__ == "__main__":
    unittest.main(exit=False)
