
Mock an imported function for all test cases in a test class

I have a few test cases, like the examples below, in which I mock a function that calls functions that are not available locally when the unit tests run.

The first statement in is_production() calls a function that is not available, so I would like to mock is_production() to return True (or False in another class, once I get this working):

def is_production() -> bool:
    db = get_current("XXX")
    ...

import tempfile
from unittest import TestCase
from unittest.mock import patch

import git

from runner.runner import Runner


class Production(TestCase):
    @patch('runner.runner.is_production')
    def test_modifications(self, is_production_mock):
        is_production_mock.return_value = True
        with tempfile.TemporaryDirectory() as bare_repo_folder:
            with git.Repo.init(bare_repo_folder, bare=False) as repo:
                ...
                # Execute test
                r = Runner(repo.working_dir, dummy.name)
                self.assertRaises(git.exc.RepositoryDirtyError, r.run)

    @patch('runner.runner.is_production')
    def test_no_tag(self, is_production_mock):
        is_production_mock.return_value = True
        with tempfile.TemporaryDirectory() as bare_repo_folder:
            with git.Repo.init(bare_repo_folder, bare=False) as repo:
                with tempfile.TemporaryFile(dir=bare_repo_folder, suffix=".py") as dummy:
                    r = Runner(repo.working_dir, dummy.name)
                    self.assertRaises(RuntimeError, r.run)

This works: the tests run successfully, so the patching behaves as I want.

Since I need to add many more tests, I was hoping I could apply the patch in a new base class derived from TestCase and put all the test cases in a subclass of it. That way I would only need to define the patch once, as suggested in an answer on SO, but I can't get it to work. This is what I currently have:

from unittest.mock import patch
from unittest import TestCase

...

class ProductionTestCase(TestCase):
    def setUp(self):
        self.patcher = patch("runner.runner.is_production", True)
        self.is_production = self.patcher.start()
        self.addClassCleanup(self.patcher.stop())


class Production(ProductionTestCase):
    def test_modifications(self):
        with tempfile.TemporaryDirectory() as bare_repo_folder:
            ...

    def test_no_tag(self):
        with tempfile.TemporaryDirectory() as bare_repo_folder:
            ...

Apparently the patch is not in place, because the real function still gets called when the unit tests run:

Error
Traceback (most recent call last):
  File "XXX\runner\tests\test_clinical.py", line 48, in test_modifications
    self.assertRaises(git.exc.RepositoryDirtyError, r.run)
  File "XXX\runner\runner.py", line 67, in run
    self.check_repo()
  File "XXX\runner\runner.py", line 52, in check_repo
    if is_production():
  File "XXX\runner\utils.py", line 33, in is_production
    db = get_current("XXX")
NameError: name 'get_current' is not defined

Any pointers on how to properly patch is_production in module runner.runner so that it returns True when the unit tests run?
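Two details in the setUp attempt look responsible. First, self.addClassCleanup(self.patcher.stop()) calls stop() immediately inside setUp, so the patch is undone before the test body runs (consistent with the traceback reaching the real is_production), and the return value None is what gets registered as the cleanup; passing the bound method self.patcher.stop to addCleanup defers it until after each test. Second, patch(target, True) replaces the function with the boolean True, so once the patch stays active, if is_production(): would fail with TypeError: 'bool' object is not callable; return_value=True installs a MagicMock that returns True instead. The sketch below is self-contained and therefore assumes stand-ins: is_production and check_repo are hypothetical functions defined in the test module itself (patched through f"{__name__}.is_production" in place of runner.runner.is_production), and PatchedProductionTestCase and its PRODUCTION attribute are hypothetical names for the base class described in the question.

```python
import unittest
from unittest import TestCase
from unittest.mock import patch


def is_production() -> bool:
    # Hypothetical stand-in for runner.utils.is_production: it depends on
    # something that is missing in the test environment.
    raise NameError("name 'get_current' is not defined")


def check_repo() -> str:
    # Hypothetical stand-in for Runner.check_repo: is_production is looked
    # up in this module's namespace each time check_repo() runs.
    return "production" if is_production() else "development"


class PatchedProductionTestCase(TestCase):
    """Hypothetical base class: patches is_production around every test."""

    # Subclasses override this to choose what the mock returns.
    PRODUCTION = True

    def setUp(self):
        super().setUp()
        # return_value makes the patched name a MagicMock returning PRODUCTION;
        # patch(target, True) would replace the function with the bool itself.
        self.patcher = patch(f"{__name__}.is_production",
                             return_value=self.PRODUCTION)
        self.is_production = self.patcher.start()
        # Pass the bound method, not its result: stop() must run after the test.
        self.addCleanup(self.patcher.stop)


class Production(PatchedProductionTestCase):
    def test_check_repo(self):
        self.assertEqual(check_repo(), "production")
        self.is_production.assert_called_once_with()


class Development(PatchedProductionTestCase):
    PRODUCTION = False

    def test_check_repo(self):
        self.assertEqual(check_repo(), "development")
```

Run it with python -m unittest on the file. To patch once per class instead of once per test, the patcher could be started in a setUpClass classmethod and stopped with cls.addClassCleanup(patcher.stop) (available since Python 3.8); the mock is then shared, so its call counts accumulate across tests.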

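This worked for me:

from unittest import TestCase
from unittest.mock import patch


@patch('runner.runner.is_production', return_value=True)
class Production(TestCase):

    def test_modifications(self, is_production_mock):
        ...

    def test_no_tag(self, is_production_mock):
        ...

Used as a class decorator, patch applies to every method whose name starts with test. Because return_value is passed instead of a replacement object, is_production becomes a MagicMock that returns True, and each test method receives that mock as an extra argument.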
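For comparison, here is a self-contained sketch of the class-decorator approach, again assuming hypothetical stand-ins (is_production and check_repo defined in the test module and patched through f"{__name__}.is_production" rather than runner.runner.is_production). It also covers the False case in a second class, which the question mentions wanting:

```python
import unittest
from unittest import TestCase
from unittest.mock import patch


def is_production() -> bool:
    # Hypothetical stand-in for the real function, which cannot run in tests.
    raise NameError("name 'get_current' is not defined")


def check_repo() -> str:
    # Hypothetical stand-in for the code under test.
    return "production" if is_production() else "development"


# As a class decorator, patch wraps every method whose name starts with
# patch.TEST_PREFIX ("test"); each one receives the mock as an extra argument.
@patch(f"{__name__}.is_production", return_value=True)
class Production(TestCase):
    def test_check_repo(self, is_production_mock):
        self.assertEqual(check_repo(), "production")
        is_production_mock.assert_called_once_with()

    def test_mock_is_fresh_per_test(self, is_production_mock):
        # A new MagicMock is created for each decorated test method.
        self.assertEqual(is_production_mock.call_count, 0)


@patch(f"{__name__}.is_production", return_value=False)
class Development(TestCase):
    def test_check_repo(self, is_production_mock):
        self.assertEqual(check_repo(), "development")
```

Because the decorator wraps each test method with its own copy of the patcher, every test gets a fresh mock; the trade-off compared with the setUp approach is the extra mock parameter on each method.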
