
Using Python unit-testing libraries (unittest, mock), how do I assert whether a method of class B was called within a method of class A?

Assuming the following setup:

class A:
    def __init__(self, nodes):
        self.nodes=nodes

    def update(self, bool_a=True):
        if bool_a:
            for n in self.nodes:
                if hasattr(self.nodes[n], 'update'):
                    self.nodes[n].update()

class B:
    def __init__(self, int_attr=5):
        self.int_attr=int_attr

    def update(self):
        self.int_attr = 0

Let us assume that self.nodes in class A is in fact a dict whose values are instances of class B (A.update looks each node up by its key).

How do I write a unit test for the update method of class A to check whether the update method of each class B node contained in self.nodes of class A was called?

In a more general setup, let us assume that multiple classes implement an update method, and instances of any of them can be nodes in self.nodes of class A. How do I check that the update methods of all self.nodes members were called?

I have tried the following, unsuccessfully:

mock_obj = MagicMock()
@patch('module.A.update', return_value=mock_obj)
def test_update(self, mock_obj):
    nodes = {}
    nodes['first'] = B(int_attr=1)
    nodes['second'] = B(int_attr=2)
    test_A = module.A(nodes=nodes)
    test_A.update(bool_A=True)
    self.assertTrue(mock_obj.called)

as suggested in "mocking a function within a class method".

Edit: Consider this particular case:

import unittest
import mock
from unittest import TestCase

class A:
    def __init__(self, nodes):
        self.nodes=nodes

    def update(self, bool_a=True):
        if bool_a:
            to_update = [n for n in self.nodes]
            while len(to_update) > 0:
                if hasattr(self.nodes[to_update[-1]], 'update'):
                    self.nodes[to_update[-1]].update()
                    print('Update called.')
                    if self.nodes[to_update[-1]].is_updated:
                        to_update.pop()

class B:
    def __init__(self, int_attr=5):
        self.int_attr=int_attr
        self.is_updated = False

    def update(self):
        self.int_attr = 0
        self.is_updated = True

class TestEnsemble(TestCase):
    def setUp(self):
        self.b1 = B(1)
        self.b2 = B(2)
        self.b3 = B(3)
        self.nodes = {}
        self.nodes['1'] = self.b1
        self.nodes['2'] = self.b2
        self.nodes['3'] = self.b3
        self.a = A(self.nodes)

    @mock.patch('module.B.update')
    def test_update(self, mock_update):
        mock_update.return_value = None
        self.a.update()
        with self.subTest():
            self.assertEqual(mock_update.call_count, 3)

Running unittest on this case results in an endless loop: because the update method of class B is mocked, the is_updated attribute never gets set to True. How can I count the number of times B.update was called within A.update in this case?

Update: I tried this:

@mock.patch('dummy_script.B')
def test_update(self, mock_B):
    self.a.update()
    with self.subTest():
        self.assertEqual(mock_B.update.call_count, 3)

The update function does run 3 times now (I can see "Update called." printed three times in the console), but the call_count of the update method stays at zero. Am I inspecting the wrong attribute or object?
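A likely explanation, sketched with a standalone MagicMock (`MockB` is an illustrative stand-in for the patched class): calling a patched class returns `MockB.return_value`, so calls made on its "instances" are recorded on `MockB.return_value.update`, not on `MockB.update`. In addition, the nodes built in `setUp` are real `B` objects created before the patch decorator became active, so their real `update` runs, which is consistent with "Update called." being printed three times.

```python
from unittest import mock

MockB = mock.MagicMock()  # stands in for the patched class B

node = MockB(1)  # calling the class mock returns MockB.return_value
node.update()

print(node is MockB.return_value)            # True
print(MockB.update.call_count)               # 0: nothing called update on the class itself
print(MockB.return_value.update.call_count)  # 1: the call went through the "instance"
```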

How do I write a unit test for TestA.test_update() to see if B.update() was called?

This is just to give some ideas.

import unittest
from unittest import mock  # Python 3; on Python 2 use the 'mock' backport instead

from module import A, B  # 'module' is the module that defines A and B

class TestA(unittest.TestCase):

    # mock away only the update method of class B
    @mock.patch.object(B, 'update')
    def test_update(self, mockb_update):
        # B.update() does not return anything
        mockb_update.return_value = None
        nodes = {}
        nodes['first'] = B(int_attr=1)
        nodes['second'] = B(int_attr=2)
        test_A = A(nodes)
        # the keyword argument of A.update is bool_a (lowercase)
        test_A.update(bool_a=True)
        self.assertTrue(mockb_update.called)

How do I check that B.update() was called for every node in A.nodes?

    # same as above, except for the final assertion
    self.assertEqual(mockb_update.call_count, 2)
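A self-contained sketch of the same idea, assuming A and B are defined as in the first version of the question (copied inline so the snippet runs on its own). With `autospec=True`, the patched `B.update` behaves like a plain function on the class, so each call records the `B` instance it was invoked on; that lets you check every node individually rather than only the total count. The test name `test_every_node_updated` is illustrative.

```python
import unittest
from unittest import mock


class A:
    def __init__(self, nodes):
        self.nodes = nodes

    def update(self, bool_a=True):
        if bool_a:
            for n in self.nodes:
                if hasattr(self.nodes[n], 'update'):
                    self.nodes[n].update()


class B:
    def __init__(self, int_attr=5):
        self.int_attr = int_attr

    def update(self):
        self.int_attr = 0


class TestA(unittest.TestCase):
    def test_every_node_updated(self):
        nodes = {'first': B(int_attr=1), 'second': B(int_attr=2)}
        a = A(nodes)
        # autospec=True makes the mock behave like the original function,
        # so each call receives (and records) the B instance as `self`
        with mock.patch.object(B, 'update', autospec=True) as mock_update:
            a.update(bool_a=True)
        mock_update.assert_has_calls(
            [mock.call(nodes['first']), mock.call(nodes['second'])],
            any_order=True,
        )
        self.assertEqual(mock_update.call_count, 2)
```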

Running into an endless loop when B.is_updated is not mocked (after the OP updated the code)

Mocking B.is_updated inside __init__, or mocking the class __init__ itself, is a topic more complicated than the original post.

Here are a few thoughts: B.is_updated cannot simply be patched with mock.patch, because it is an instance attribute that only exists after an instance of B has been initialized. So the options are to

a) mock B.__init__ (the class constructor), or

b) mock the whole class B, which is easier in your case; setting is_updated to True will end the endless loop.
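A sketch of option b), assuming the loop version of A and B from the question's edit (the `print` call is omitted). Instead of patching the module, each node here is a mock built with `mock.create_autospec(B, instance=True)`, which only exposes what class B defines. Because `is_updated` is assigned in `B.__init__` rather than on the class, the spec does not include it, so the sketch sets it to True by hand to let the `while` loop finish.

```python
import unittest
from unittest import mock


class A:
    def __init__(self, nodes):
        self.nodes = nodes

    def update(self, bool_a=True):
        if bool_a:
            to_update = [n for n in self.nodes]
            while len(to_update) > 0:
                if hasattr(self.nodes[to_update[-1]], 'update'):
                    self.nodes[to_update[-1]].update()
                    if self.nodes[to_update[-1]].is_updated:
                        to_update.pop()


class B:
    def __init__(self, int_attr=5):
        self.int_attr = int_attr
        self.is_updated = False

    def update(self):
        self.int_attr = 0
        self.is_updated = True


class TestA(unittest.TestCase):
    def test_update_with_mocked_nodes(self):
        nodes = {}
        for key in ('1', '2', '3'):
            # a mock restricted to the attributes defined on class B
            node = mock.create_autospec(B, instance=True)
            # is_updated is set in B.__init__, so the spec does not know it;
            # set it explicitly so that A.update's while-loop terminates
            node.is_updated = True
            nodes[key] = node
        A(nodes).update()
        for node in nodes.values():
            node.update.assert_called_once_with()
```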

Combining the answers from @Gang and @jonrsharpe, the following code snippets solve the questions I asked above:

How do I write a unit test for TestA.test_update() to see if B.update() was called? See @Gang's answer.

How do I check that B.update() was called for every node in A.nodes? See @Gang's answer.

Running into an endless loop when B.is_updated is not mocked (after the OP updated the code)

As @jonrsharpe suggested, here the solution is to create one mock object per B instance in nodes and check for function calls separately:

class TestA(TestCase):

    @mock.patch('module.B')
    @mock.patch('module.B')
    @mock.patch('module.B')
    def test_update(self, mock_B1, mock_B2, mock_B3):
        nodes = {}
        nodes['1'] = mock_B1
        nodes['2'] = mock_B2
        nodes['3'] = mock_B3
        a = A(nodes)
        a.update()
        with self.subTest():
            self.assertEqual(mock_B1.update.call_count, 1)
        with self.subTest():
            self.assertEqual(mock_B2.update.call_count, 1)
        with self.subTest():
            self.assertEqual(mock_B3.update.call_count, 1)
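A variant, sketched under the same assumptions (loop version of A and B, `print` omitted), that keeps real B instances in nodes: patch `B.update` with `autospec=True` and give it a `side_effect` that sets only the `is_updated` flag. The helper `_mark_updated` is hypothetical; it stands in for whatever state the real method must leave behind for A.update to terminate.

```python
import unittest
from unittest import mock


class A:
    def __init__(self, nodes):
        self.nodes = nodes

    def update(self, bool_a=True):
        if bool_a:
            to_update = [n for n in self.nodes]
            while len(to_update) > 0:
                if hasattr(self.nodes[to_update[-1]], 'update'):
                    self.nodes[to_update[-1]].update()
                    if self.nodes[to_update[-1]].is_updated:
                        to_update.pop()


class B:
    def __init__(self, int_attr=5):
        self.int_attr = int_attr
        self.is_updated = False

    def update(self):
        self.int_attr = 0
        self.is_updated = True


def _mark_updated(node):
    # hypothetical helper: reproduces only the part of B.update that
    # A.update depends on (the is_updated flag), so the loop can finish
    node.is_updated = True


class TestA(unittest.TestCase):
    def test_update_counts_calls(self):
        nodes = {'1': B(1), '2': B(2), '3': B(3)}
        a = A(nodes)
        # autospec=True passes the B instance to the mock and to side_effect
        with mock.patch.object(B, 'update', autospec=True,
                               side_effect=_mark_updated) as mock_update:
            a.update()
        self.assertEqual(mock_update.call_count, 3)
        # the real B.update never ran, so int_attr is untouched
        self.assertEqual([n.int_attr for n in nodes.values()], [1, 2, 3])
```

Compared with one mock per node, this counts the calls of all instances on a single mock; because of autospec, `mock_update.call_args_list` also shows which instance each call came from.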

Additionally, if for some reason you want the mocked functions to actually run (for example, because they set flags or variables that affect control flow at runtime), you can wrap the real methods instead of replacing them:

def test_fit_skip_ancestors_all(self):
    # a method of a TestCase subclass; the nodes are real B instances
    nodes = {}
    nodes['1'] = B(1)
    nodes['2'] = B(2)
    nodes['3'] = B(3)
    a = A(nodes)
    # wraps= forwards each call to the real B.update, so is_updated still gets set
    with mock.patch.object(a.nodes['1'], 'update', wraps=a.nodes['1'].update) as mock_B1, \
         mock.patch.object(a.nodes['2'], 'update', wraps=a.nodes['2'].update) as mock_B2, \
         mock.patch.object(a.nodes['3'], 'update', wraps=a.nodes['3'].update) as mock_B3:
        a.update()
    with self.subTest():
        self.assertEqual(mock_B1.call_count, 1)
    with self.subTest():
        self.assertEqual(mock_B2.call_count, 1)
    with self.subTest():
        self.assertEqual(mock_B3.call_count, 1)
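The three nested `patch.object` calls above can be generalized to any number of nodes, and to nodes of different classes (the more general setup from the question), with `contextlib.ExitStack`. A minimal sketch, assuming the loop version of A and B (`print` omitted); `C` is a hypothetical second node type.

```python
import contextlib
import unittest
from unittest import mock


class A:
    def __init__(self, nodes):
        self.nodes = nodes

    def update(self, bool_a=True):
        if bool_a:
            to_update = [n for n in self.nodes]
            while len(to_update) > 0:
                if hasattr(self.nodes[to_update[-1]], 'update'):
                    self.nodes[to_update[-1]].update()
                    if self.nodes[to_update[-1]].is_updated:
                        to_update.pop()


class B:
    def __init__(self, int_attr=5):
        self.int_attr = int_attr
        self.is_updated = False

    def update(self):
        self.int_attr = 0
        self.is_updated = True


class C:
    # hypothetical second node type with its own update method
    def __init__(self):
        self.is_updated = False

    def update(self):
        self.is_updated = True


class TestA(unittest.TestCase):
    def test_all_nodes_updated(self):
        nodes = {'b1': B(1), 'b2': B(2), 'c': C()}
        a = A(nodes)
        with contextlib.ExitStack() as stack:
            # one spy per node: wraps= forwards to the real update method,
            # so flags such as is_updated are still set
            spies = {
                key: stack.enter_context(
                    mock.patch.object(node, 'update', wraps=node.update))
                for key, node in nodes.items()
                if hasattr(node, 'update')
            }
            a.update()
        # the spies keep their recorded calls after the patches are undone
        for key, spy in spies.items():
            with self.subTest(node=key):
                spy.assert_called_once_with()
```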

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 