Python unittest: Mock an external library function called from a class object

Hello, I have the following code.

I am trying to test the load method in file_a; download is a function from an external module (foo) that I import:

file_a.py

from foo import download

class Bar()
    __init__(self, arg_1):
        self.var = arg_1

    def load(self):
        if self.var == "latest_lib":
            download("latest_lib")

I wrote the test like this:

test.py

@patch(file_a.download)
def test_download():
    import file_a
    bar = file_a.Bar("latest_lib")
    bar.load()
    file_a.download.assert_called() 

But it seems like the bar object is not calling the mock download, but rather the real imported download. How can I fix this and make my test pass?

I tried your code to see what was wrong. There were some syntax errors, which don't really matter; the main issue is that you need to pass a string to patch to make it work.

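Here's my modified version of your code, which works (I used pprint's pprint, imported as pp, in place of foo.download so it runs without the external library):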

# file_a.py

from pprint import pprint as pp

class Bar():
    def __init__(self, arg_1):
        self.var = arg_1

    def load(self):
        if self.var == "latest_lib":
            pp("latest_lib")

And:

# test_file_a.py

import unittest
from unittest.mock import patch


class TestStringMethods(unittest.TestCase):
    @patch("file_a.pp")
    def test_download(self, pp):
        import file_a
        bar = file_a.Bar("latest_lib")
        bar.load()
        pp.assert_called()


if __name__ == "__main__":
    unittest.main()

NOTES:

  1. You need to pass a string to patch so it can look up and replace the object at runtime.
  2. You have to accept the mocked object as an argument of your test function.
  3. I used class-based tests here because I wanted to stick with the standard-library unittest module, but you don't have to do the same if you're using pytest.
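A minimal sketch of notes 1 and 2, using the standard-library os.getcwd as the patch target so it runs anywhere (the target, the helper name check_cwd and the fake return value are illustrative, not from the question):

```python
import os
from unittest.mock import patch

# Note 1: the target is a string "module.attribute", resolved when the
# patch starts; it is not the function object itself.
# Note 2: used as a decorator, patch passes the mock in as an extra argument.
@patch("os.getcwd", return_value="/fake/dir")
def check_cwd(mock_getcwd):
    cwd = os.getcwd()  # os.getcwd is looked up on os at call time -> the mock
    mock_getcwd.assert_called_once_with()
    return cwd

patched_result = check_cwd()
restored = os.getcwd() != "/fake/dir"  # the real function is back afterwards

# Passing the function object instead of a string is rejected immediately.
# The exact exception type depends on the Python version.
rejected_with = None
try:
    patch(os.getcwd)
except (TypeError, AttributeError) as exc:
    rejected_with = type(exc).__name__
```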

One final note from the official docs:

The basic principle is that you patch where an object is looked up, which is not necessarily the same place as where it is defined.
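A sketch of that principle applied to the question's setup. It writes hypothetical stand-ins for foo.py and file_a.py to a temporary directory (the body of download is a placeholder, and load returns the download result only so the outcome can be checked), imports file_a, and then patches download in each place:

```python
import sys
import tempfile
import textwrap
from pathlib import Path
from unittest.mock import patch

# Hypothetical stand-ins for the question's modules, written to a temp dir.
tmp = Path(tempfile.mkdtemp())
(tmp / "foo.py").write_text("def download(name):\n    return 'real download of ' + name\n")
(tmp / "file_a.py").write_text(textwrap.dedent("""
    from foo import download

    class Bar:
        def __init__(self, arg_1):
            self.var = arg_1

        def load(self):
            if self.var == "latest_lib":
                return download("latest_lib")
"""))
sys.path.insert(0, str(tmp))
for name in ("foo", "file_a"):
    sys.modules.pop(name, None)

import file_a  # file_a now holds its own reference named "download"

# Patching where download is defined: file_a's reference is untouched,
# so the real function still runs.
with patch("foo.download") as mock_at_definition:
    result_defined = file_a.Bar("latest_lib").load()
called_at_definition = mock_at_definition.called

# Patching where download is looked up (file_a's namespace): the mock is hit.
with patch("file_a.download") as mock_at_lookup:
    file_a.Bar("latest_lib").load()
called_at_lookup = mock_at_lookup.called
```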

I think you are missing the mock setup:

@patch("foo.download")
def test_download(mock_download):
    from file_a import Bar

    Bar("latest_lib").load()
    mock_download.assert_called()
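Patching foo.download works here because file_a is imported inside the test, after foo.download has already been replaced, so the line from foo import download in file_a binds the mock. A sketch with hypothetical stand-in modules written to a temporary directory (the body of download is a placeholder) shows this, and also a side effect: file_a stays cached in sys.modules with the mock bound even after the patch ends, which can leak into later tests:

```python
import sys
import tempfile
import textwrap
from pathlib import Path
from unittest.mock import patch

# Hypothetical stand-ins for the question's modules, written to a temp dir.
tmp = Path(tempfile.mkdtemp())
(tmp / "foo.py").write_text("def download(name):\n    return 'real download of ' + name\n")
(tmp / "file_a.py").write_text(textwrap.dedent("""
    from foo import download

    class Bar:
        def __init__(self, arg_1):
            self.var = arg_1

        def load(self):
            if self.var == "latest_lib":
                return download("latest_lib")
"""))
sys.path.insert(0, str(tmp))
for name in ("foo", "file_a"):
    sys.modules.pop(name, None)  # file_a must NOT be imported yet

with patch("foo.download") as mock_download:
    from file_a import Bar  # first import: "from foo import download" binds the mock
    Bar("latest_lib").load()
called_inside = mock_download.called

import file_a
import foo

# The patch is undone on foo, but the cached file_a module keeps the mock.
foo_restored = foo.download("x") == "real download of x"
file_a_still_mocked = file_a.download is mock_download
```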
