简体   繁体   English

Python 修补(使用 unittest.mock) function 后已从模块导入

[英]Python patching (with unittest.mock) function after already imported from module

Consider that I have 2 files:考虑到我有 2 个文件:

some_module.py

def get_value_1():
    return 1


def print_value_1():
    print(get_value_1())

and main file main.py和主文件main.py

from unittest.mock import patch
import some_module
from some_module import print_value_1, get_value_1


def mocked_value():
    return 2


if __name__ == '__main__':
    print_value_1()  # prints 1
    with patch.object(some_module, 'get_value_1', mocked_value):
        print_value_1()  # prints 2
        print(some_module.get_value_1())  # prints 2
        print(get_value_1())  # prints 1 - DESIRABLE RESULT IS TO PRINT ALSO 2

As you can see because I explicitly imported the get_value_1 function, the patch is not working on it.如您所见,因为我显式导入了get_value_1 function,所以补丁不起作用。 I understand basically why, that's because it uses a reference and the reference is imported before the main ran (checked it with calling id() on each invoked function and saw the addresses).我基本上理解为什么,那是因为它使用了引用并且引用是在主运行之前导入的(通过在每个调用的 function 上调用id()并查看地址来检查它)。 Can I somehow hijack also the imported reference?我可以以某种方式劫持导入的参考吗?

(It won't be enough to patch it only in main.py, I want it to be patched all over the project, so for example in some other some_other_module.py there will be: from some_module import get_value_1 and when I call get_value_1() it will call the patched function and return the value 2) (仅在 main.py 中对其进行修补是不够的,我希望在整个项目中对其进行修补,因此例如在其他一些some_other_module.py中将有: from some_module import get_value_1和当我调用get_value_1()它将调用已修补的 function 并返回值 2)

If using patch or patch.object , there is no way around patching every reference of the module.如果使用patchpatch.object ,则无法修补模块的每个引用。 In your case this would be for example:在您的情况下,例如:

if __name__ == '__main__':
    print_value_1()
    with patch.object(some_module, 'get_value_1', mocked_value):
        with patch.object(sys.modules[__name__], 'get_value_1', mocked_value):
            print_value_1()
            print(some_module.get_value_1()) 
            print(get_value_1())

Depending on how your app structure looks, you could iterate over all modules that reference the function to be patched, eg:根据您的应用程序结构的外观,您可以遍历所有引用要修补的 function 的模块,例如:

def get_modules():
    return (sys.modules[__name__], some_module, some_other_module)

if __name__ == '__main__':
    patches = []
    for module in get_modules():
        p = patch.object(module, 'get_value_1', mocked_value)
        p.start()
        patches.append(p)
        
    print(some_module.get_value_1())
    print(some_other_module.get_value_1())
    print(get_value_1())
    
    [p.stop() for p in patches]

If you don't have all modules to patch beforehand, you have to collect all modules you need to patch at runtime (eg get_modules gets more complicated), for example by iterating over all loaded modules, find your function to be patched by name, and mock that (assuming that at that point all modules are loaded).如果您没有预先修补所有模块,则必须收集需要在运行时修补的所有模块(例如get_modules变得更加复杂),例如通过遍历所有加载的模块,找到要按名称修补的 function,并模拟它(假设此时所有模块都已加载)。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM