
How to patch/mock import?

I'm writing tests for an Airflow DAG and running into an issue mocking/patching the DAG module.

# dag.py
from airflow.models import Variable

ENVIRONMENT = Variable.get("environment")

# test_dag.py
from unittest import TestCase

import dag

class TestDAG(TestCase):
    def test_something(self):
        pass

Because I'm setting the variable at module level, outside any function or class, Variable.get() runs during import. This gives me a SQLAlchemy error because it tries to connect to a database and fetch the variable.

sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) no such table: variable
[SQL: SELECT variable.val AS variable_val, variable.id AS variable_id, variable."key" AS variable_key, variable.is_encrypted AS variable_is_encrypted 
FROM variable 
WHERE variable."key" = ?
 LIMIT ? OFFSET ?]
[parameters: ('environment', 1, 0)]

Is there a way to patch/mock airflow.models.Variable before it's imported?

You'll need to defer importing the file until you can set a Variable value into the test database. A startTestRun method would be the perfect place.

I have some workarounds but not a definitive answer:

  1. You can move the line ENVIRONMENT = Variable.get("environment") inside a function instead of leaving it at module level. That way it is not executed on import, and you can add this mock to conftest.py (the mocker fixture comes from the pytest-mock plugin):
import pytest
from airflow.models import Variable

@pytest.fixture(autouse=True)
def mock_airflow_variables(mocker):
    mocker.patch.object(target=Variable, attribute="get", return_value="test")
  2. You can import the module inside the test function. That way the mock is set before the import runs.
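
Workaround 1 could look roughly like the sketch below. The function name get_environment is hypothetical, and the test uses unittest.mock with a fake airflow.models entry in sys.modules instead of the pytest-mock fixture, purely so the example runs without Airflow installed. With Airflow available, the conftest.py fixture above would serve the same purpose.

```python
import sys
from unittest.mock import MagicMock, patch


def get_environment():
    """Fetch the Airflow variable when called, not when the module is imported."""
    # Assumption: the import sits inside the function only so this sketch loads
    # without Airflow installed; in a real dag.py it can stay at module level.
    from airflow.models import Variable

    return Variable.get("environment")


def test_get_environment_uses_mock():
    # Swap in a fake airflow.models for the duration of the call.
    fake_models = MagicMock()
    fake_models.Variable.get.return_value = "test"
    fake_modules = {
        "airflow": MagicMock(models=fake_models),
        "airflow.models": fake_models,
    }
    with patch.dict(sys.modules, fake_modules):
        assert get_environment() == "test"
    fake_models.Variable.get.assert_called_once_with("environment")
```

Since nothing touches Variable.get() at import time, the module can be imported freely and the lookup is patched only where a test needs it.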

You can mock the import before importing the file:

import unittest
from unittest.mock import MagicMock

import sys


class TestDAG(unittest.TestCase):

    def test_something(self):
        sys.modules['airflow.models'] = MagicMock()
    
        # This returns the MagicMock instance
        from airflow.models import Variable 

        # Set the return_value of the .get() call
        Variable.get.return_value = "TEST" 

        # import the dag after
        import dag

        Variable.get.assert_called_once()
        assert dag.ENVIRONMENT == "TEST"


if __name__ == '__main__':
    unittest.main()
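
One thing to watch with the approach above: assigning to sys.modules directly leaves the MagicMock (and the already-imported dag module) in place for every test that runs afterwards. A possible variation, sketched here, scopes the replacement with unittest.mock.patch.dict so that sys.modules is restored when the block exits. The helper name import_dag_with_mocked_variable is hypothetical, and the dag.py file is recreated in a temporary directory only to keep the example self-contained; in a real project you would import your own dag module.

```python
import importlib
import sys
import tempfile
import textwrap
from pathlib import Path
from unittest.mock import MagicMock, patch

# Stand-in for the question's dag.py, written to a temporary directory
# only so that this sketch runs on its own.
DAG_SOURCE = textwrap.dedent(
    """
    from airflow.models import Variable

    ENVIRONMENT = Variable.get("environment")
    """
)


def import_dag_with_mocked_variable(value):
    """Import a fresh dag module while airflow.models is a MagicMock."""
    fake_models = MagicMock()
    fake_models.Variable.get.return_value = value
    fake_modules = {
        "airflow": MagicMock(models=fake_models),
        "airflow.models": fake_models,
    }

    with tempfile.TemporaryDirectory() as tmp_dir:
        Path(tmp_dir, "dag.py").write_text(DAG_SOURCE)
        # patch.dict restores sys.modules on exit, which removes the fake
        # airflow entries and the dag module imported inside the block.
        with patch.dict(sys.modules, fake_modules), \
                patch.object(sys, "path", [tmp_dir, *sys.path]):
            sys.modules.pop("dag", None)  # force a fresh import
            dag = importlib.import_module("dag")

    return dag, fake_models.Variable
```

A test can then call dag, Variable = import_dag_with_mocked_variable("TEST") and assert on dag.ENVIRONMENT and Variable.get, without affecting later tests that need the real airflow.models.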
