繁体   English   中英

Python:无论从哪里导入模块,都可以模拟修补模块

[英]Python: mock patch a module wherever it is imported from

我需要确保运行单元测试不会触发调用一个沉重的外部世界函数,比如这个:

# bigbad.py
def request(param):
    return 'I searched the whole Internet for "{}"'.format(param)

多个模块使用这个函数 (bigbad.request) 并且它们以不同的方式导入它(在现实生活中它也可能从外部库导入)。 假设有两个模块 a 和 b,其中 b 依赖于 a 并且都使用该函数:

# a.py, from...import
from bigbad import request

def routine_a():
    return request('a')

# b.py, imports directly
import a
import bigbad

def routine_b():
    resp_a = a.routine_a()
    return 'resp_a: {}, resp_b=request(resp_a): {}'.format(resp_a, bigbad.request(resp_a))

有没有办法确保 bigbad.request 永远不会被调用? 此代码仅模拟其中一个导入:

# test_b.py
import unittest
from unittest import mock
import b

with mock.patch('bigbad.request') as mock_request:
    mock_request.return_value = 'mocked'
    print(b.routine_b())

显然我可以重构 b 并更改导入,但是这样我不能保证在未来的开发过程中有人不会破坏这个规定。 我相信测试应该测试行为而不是实现细节。

import bigbad
bigbad.request = # some dummy function

只要它在运行/导入from bigbad import request任何模块之前运行,它就会起作用。 也就是说,只要他们跑了,他们就会收到哑函数。

# a.py, from...import
from bigbad import request

为确保永远不会调用原始request ,您必须修补导入引用的所有位置:

import mock
with mock.patch('a.request', return_value='mocked') as mock_request:
    ...

这很乏味,所以如果可能的话,不要在代码中使用from bigbad import request ,而是使用import bigbad; bigbad.request import bigbad; bigbad.request

另一个解决方案:如果可能,更改bigbad.py

# bigbad.py
def _request(param):
    return 'I searched the whole Internet for "{}"'.format(param)


def request(param):
    return _request(param)

然后,即使某些代码from bigbad import request ,您也可以with mock.patch('bigbad._request', return_value='mocked') as mock_request:

对于将来遇到这个问题的任何人,我编写了一个函数来修补给定符号的所有导入。

此函数为给定符号(整个模块、特定函数或任何其他对象)的每次导入返回修补程序列表。 然后可以在测试夹具的设置/拆卸区域启动/停止这些修补程序(请参阅文档字符串以获取示例)。

这个怎么运作:

  • 遍历sys.modules每个当前可见的模块
  • 如果模块的名称以match_prefix (可选)开头并且不包含skip_substring (可选),则遍历模块中的每个局部
  • 如果本地是target_symbol ,为它创建一个修补程序,在它导入的模块本地

我建议使用像skip_substring='test'这样的参数,这样你就不会修补测试套件导入的东西。

from typing import Any, Optional
import unittest.mock as mock
import sys

def patch_all_symbol_imports(
        target_symbol: Any, match_prefix: Optional[str] = None,
        skip_substring: Optional[str] = None
):
    """
    Iterate through every visible module (in sys.modules) that starts with
    `match_prefix` to find imports of `target_symbol` and return a list
    of patchers for each import.

    This is helpful when you want to patch a module, function, or object
    everywhere in your project's code, even when it is imported with an alias.

    Example:

    ::

        import datetime

        # Setup
        patchers = patch_all_symbol_imports(datetime, 'my_project.', 'test')
        for patcher in patchers:
            mock_dt = patcher.start()
            # Do stuff with the mock

        # Teardown
        for patcher in patchers:
            patcher.stop()

    :param target_symbol: the symbol to search for imports of (may be a module,
        a function, or some other object)
    :param match_prefix: if not None, only search for imports in
        modules that begin with this string
    :param skip_substring: if not None, skip any module that contains this
        substring (e.g. 'test' to skip unit test modules)
    :return: a list of patchers for each import of the target symbol
    """

    patchers = []

    # Iterate through all currently imported modules
    # Make a copy in case it changes
    for module in list(sys.modules.values()):
        name_matches = (
                match_prefix is None
                or module.__name__.startswith(match_prefix)
        )
        should_skip = (
            skip_substring is not None and skip_substring in module.__name__
        )
        if not name_matches or should_skip:
            continue

        # Iterate through this module's locals
        # Again, make a copy
        for local_name, local in list(module.__dict__.items()):
            if local is target_symbol:
                # Patch this symbol local to the module
                patchers.append(mock.patch(
                    f'{module.__name__}.{local_name}', autospec=True
                ))

    return patchers

对于这个问题,可以使用以下代码:

from bigbad import request

patchers = patch_all_symbol_imports(request, skip_substring='test')
for patcher in patchers:
    mock_request = patcher.start()
    mock_request.return_value = 'mocked'

print(b.routine_b())

for patcher in patchers:
    patcher.stop()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM