
Using patch to mock a function (as opposed to a method)

I want to do something like the following example (found here):

>>> with patch.object(ProductionClass, 'method', return_value=None) as mock_method:
...     thing = ProductionClass()
...     thing.method(1, 2, 3)

However, this patches the method named method on ProductionClass. I want to patch a generic function within a context. Ideally, something like this:

with patch.something(my_fn, return_value=my_return) as mock_function:
    do_some_other_fn()

my_fn is called deep within do_some_other_fn and is therefore difficult to mock out directly. This seems like it should be straightforward, but I can't find the right syntax.

EDIT: In the module where do_some_other_fn lives, I import my_fn as follows:

from my_module import my_fn

So I need a way to tell mock to patch that from outside the module. Is this possible?

EDIT 2: I think this makes it clearer what I am looking for.

This works but is not ideal:

import my_module
with patch('my_module.fn', return_value='hello') as patch_context:
    x = my_module.fn()
    # x now contains 'hello'

However, I would much rather have it work like this (or something similar):

from my_module import fn
with patch('my_module.fn', return_value='hello') as patch_context:
    x = fn()
    # x contains real result from real call to fn()

Your attempt with from my_module import fn does not work because the import statement creates a local name fn bound to whatever object my_module.fn refers to at the time of import. Patching my_module.fn later replaces the attribute on the module, but that does not matter, because your local fn still refers to the original function.
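The name-binding behaviour can be checked directly. This is a minimal sketch; the module my_module is hypothetical and is built in memory (via types.ModuleType and sys.modules) only so the example runs on its own.

```python
import sys
import types
from unittest.mock import patch

# Hypothetical "my_module", created in memory so the example is self-contained.
my_module = types.ModuleType("my_module")
exec("def fn():\n    return 'real'\n", my_module.__dict__)
sys.modules["my_module"] = my_module

from my_module import fn  # binds a local name "fn" to the current function object

with patch("my_module.fn", return_value="hello"):
    via_module = my_module.fn()  # attribute looked up on the module at call time: the mock
    via_local = fn()             # local name still points at the original function

print(via_module, via_local)  # hello real
```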

If the file that contains the patch call is the main module (the file that Python initially loaded), you should be able to do it by patching __main__.fn:

from my_module import fn
with patch('__main__.fn', return_value='hello') as patch_context:
    x = fn()

If the file that contains the patch call is instead imported as a module from the main module, __main__ won't work; pass patch the absolute module name of the module that contains your patch call instead of __main__.
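One way to avoid hard-coding the target (a suggestion, not something from the answer above) is to build it from __name__, which is "__main__" when the file runs as a script and the absolute module name when it is imported. In this sketch, make_module, my_module and my_tests are hypothetical scaffolding that simulates two files in memory.

```python
import sys
import types

def make_module(name, source):
    """Hypothetical helper: build a module from source and register it in sys.modules."""
    module = types.ModuleType(name)
    sys.modules[name] = module  # register first so patch() can import it by name
    exec(source, module.__dict__)
    return module

make_module("my_module", "def fn():\n    return 'real'\n")

# Hypothetical test module that imports fn by name and patches its own reference.
tests = make_module("my_tests", '''
from unittest.mock import patch
from my_module import fn

def run():
    # __name__ is "my_tests" here (it would be "__main__" if run as a script),
    # so the target always names the module that holds the local "fn".
    with patch(__name__ + ".fn", return_value="hello"):
        return fn()
''')

print(tests.run())  # hello
```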

You can think of a function as an attribute of its module object, much like a static method is an attribute of a class. To patch a function func in module mymodule, you can use:

patch("mymodule.func", return_value=my_return)

Take care about where to patch: if the function is in the same module as your test, and that module is run as the main script, use "__main__.func" as the patch target.

Like patch.object, patch can be used as a decorator, as a context manager, or through the start() and stop() methods of the patcher it returns.
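The three styles can be compared side by side. This is a minimal sketch; mymodule is created in memory here only so the example is self-contained.

```python
import sys
import types
from unittest.mock import patch

# Hypothetical "mymodule", created in memory so the example runs on its own.
mymodule = types.ModuleType("mymodule")
exec("def func():\n    return 'orig'\n", mymodule.__dict__)
sys.modules["mymodule"] = mymodule

# 1. As a decorator: the mock is passed in as an extra positional argument.
@patch("mymodule.func", return_value="decorated")
def use_decorator(mock_func):
    return mymodule.func()

# 2. As a context manager.
def use_context():
    with patch("mymodule.func", return_value="context"):
        return mymodule.func()

# 3. With start() / stop(), e.g. from a TestCase's setUp / tearDown.
def use_start_stop():
    patcher = patch("mymodule.func", return_value="started")
    patcher.start()
    try:
        return mymodule.func()
    finally:
        patcher.stop()  # always undo the patch

print(use_decorator(), use_context(), use_start_stop(), mymodule.func())
# decorated context started orig
```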

Now, when a module imports a function from another module like this:

from mymodule import func as foo

it creates a new reference to func in the new module, named foo. Every time this module calls foo, it uses the reference to mymodule.func that was loaded at import time; if you want to change this behavior, you should patch foo in the new module.

To make this clearer, I built an example with mymodule, which contains func; module_a, which imports mymodule and uses mymodule.func; and module_b, which uses from mymodule import func as foo and uses both foo and mymodule.func.

mymodule.py

def func():
    return "orig"

module_a.py

import mymodule

def a():
    return mymodule.func()

module_b.py

from mymodule import func as foo
import mymodule

def b_foo():
    return foo()

def b():
    return mymodule.func()

test.py

import unittest
from unittest.mock import patch
import mymodule
import module_a
import module_b

class Test(unittest.TestCase):
    def test_direct(self):
        self.assertEqual(mymodule.func(), "orig")
        with patch("mymodule.func", return_value="patched"):
            self.assertEqual(mymodule.func(), "patched")

    def test_module_a(self):
        self.assertEqual(module_a.a(), "orig")
        with patch("mymodule.func", return_value="patched"):
            self.assertEqual(module_a.a(), "patched")

    def test_module_b(self):
        self.assertEqual(module_b.b(), "orig")
        with patch("mymodule.func", return_value="patched"):
            self.assertEqual(module_b.b(), "patched")
            self.assertEqual(module_b.b_foo(), "orig")
        with patch("module_b.foo", return_value="patched"):
            self.assertEqual(module_b.b(), "orig")
            self.assertEqual(module_b.b_foo(), "patched")


if __name__ == '__main__':
    unittest.main()

In other words, what really decides where to patch is how the function is referenced in the code where you want the patched version to be used.
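Applied to the original question (my_fn called deep inside do_some_other_fn, whose module uses from my_module import my_fn), this means patching my_fn in the module where do_some_other_fn lives. A minimal sketch; other_module, helper and make_module are hypothetical names, and the modules are built in memory only so the example runs on its own.

```python
import sys
import types
from unittest.mock import patch

def make_module(name, source):
    """Hypothetical helper: build a module from source and register it in sys.modules."""
    module = types.ModuleType(name)
    sys.modules[name] = module
    exec(source, module.__dict__)
    return module

make_module("my_module", "def my_fn():\n    return 'real'\n")

# Hypothetical module where do_some_other_fn lives; it imports my_fn by name.
other_module = make_module("other_module", '''
from my_module import my_fn

def helper():
    return my_fn()  # the "deep" call looks up other_module.my_fn

def do_some_other_fn():
    return "result: " + helper()
''')

# Patch the name where it is looked up, not where it is defined.
with patch("other_module.my_fn", return_value="mocked") as mock_fn:
    result = other_module.do_some_other_fn()

print(result)              # result: mocked
print(mock_fn.call_count)  # 1
```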
