
Python mock multiple queries in a function using pytest_mock

I am writing a unit test case for a function that contains multiple SQL queries. I am using the psycopg2 module and trying to mock the cursor.

app.py

import psycopg2

def my_function():
    # all connection related code goes here ...

    query = "SELECT name,phone FROM customer WHERE name='shanky'"
    cursor.execute(query)
    columns = [i[0] for i in cursor.description]
    customer_response = []
    for row in cursor.fetchall():
        customer_response.append(dict(zip(columns, row)))

    query = "SELECT name,id FROM product WHERE name='soap'"
    cursor.execute(query)
    columns = [i[0] for i in cursor.description]
    product_response = []
    for row in cursor.fetchall():
        product_response.append(dict(zip(columns, row)))

    return product_response

test.py

from pytest_mock import mocker
import psycopg2

def test_my_function(mocker):
    from my_module import app
    mocker.patch('psycopg2.connect')

    #first query
    mocked_cursor_one = psycopg2.connect.return_value.cursor.return_value
    mocked_cursor_one.description = [['name'],['phone']]
    mocked_cursor_one.fetchall.return_value = [('shanky', '347539593')]
    mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'"

    #second query
    mocked_cursor_two = psycopg2.connect.return_value.cursor.return_value
    mocked_cursor_two.description = [['name'],['id']]
    mocked_cursor_two.fetchall.return_value = [('nirma', 12313)]
    mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'"

    ret = app.my_function()
    assert ret == {'name' : 'nirma', 'id' : 12313}

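But the mocker always takes the last mock object (the second query). I have tried multiple hacks, but none of them worked. How can I mock multiple queries in one function and pass the unit test case successfully? Is it possible to write the unit test this way, or do I need to split the queries into different functions?

Try the side_effect parameter of mocker.patch: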

from unittest.mock import MagicMock
from pytest_mock import mocker
import psycopg2

def test_my_function(mocker):
    from my_module import app
    mocker.patch('psycopg2.connect', side_effect=[MagicMock(), MagicMock()])

    #first query
    mocked_cursor_one = psycopg2.connect().cursor.return_value  # note that we actually call psycopg2.connect -- it's important
    mocked_cursor_one.description = [['name'],['phone']]
    mocked_cursor_one.fetchall.return_value = [('shanky', '347539593')]
    mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'"

    #second query
    mocked_cursor_two = psycopg2.connect().cursor.return_value
    mocked_cursor_two.description = [['name'],['id']]
    mocked_cursor_two.fetchall.return_value = [('nirma', 12313)]
    mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'"

    assert mocked_cursor_one is not mocked_cursor_two  # show that they are different

    ret = app.my_function()
    assert ret == {'name' : 'nirma', 'id' : 12313}

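According to the documentation, side_effect lets you change the return value each time the patched object is called:

If you pass in an iterable, it is used to retrieve an iterator which must yield a value on every call. This value can either be an exception instance to be raised, or a value to be returned from the call to the mock.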
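The quoted behaviour can be tried in isolation. This is a minimal sketch, assuming a bare MagicMock as a hypothetical stand-in for the cursor (no psycopg2 connection is involved); the row values are the ones from the question.

```python
from unittest.mock import MagicMock

# Hypothetical stand-in for a database cursor; nothing connects to a database.
cursor = MagicMock()

# With an iterable side_effect, each call returns the next item in order.
cursor.fetchall.side_effect = [
    [('shanky', '347539593')],  # rows handed back for the first query
    [('nirma', 12313)],         # rows handed back for the second query
]

first_rows = cursor.fetchall()
second_rows = cursor.fetchall()

assert first_rows == [('shanky', '347539593')]
assert second_rows == [('nirma', 12313)]
```

Once the iterable is exhausted, a further call raises StopIteration, so the list should hold one entry per expected call.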

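After a lot of digging through the documentation, I was able to achieve this with the unittest mock decorator and side_effect, as suggested by @Pavel Vergeev. I was able to write a unit test case that tests the functionality sufficiently.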

from unittest import mock
from my_module import app

@mock.patch('psycopg2.connect')
def test_my_function(mocked_db):

    mocked_cursor = mocked_db.return_value.cursor.return_value

    description_mock = mock.PropertyMock()
    type(mocked_cursor).description = description_mock

    fetchall_return_one = [('shanky', '347539593')]

    fetchall_return_two = [('nirma', 12313)]

    descriptions = [
        [['name'],['phone']],
        [['name'],['id']]
    ]

    mocked_cursor.fetchall.side_effect = [fetchall_return_one, fetchall_return_two]

    description_mock.side_effect = descriptions

    ret = app.my_function()

    # assert both mocked fetchall results were consumed
    assert mocked_cursor.fetchall.call_count == 2

    # assert db query count is 2
    assert mocked_cursor.execute.call_count == 2

    # first query
    query1 = "SELECT name,phone FROM customer WHERE name='shanky'"
    assert mocked_cursor.execute.call_args_list[0][0][0] == query1

    # second query
    query2 = "SELECT name,id FROM product WHERE name='soap'"
    assert mocked_cursor.execute.call_args_list[1][0][0] == query2

    # assert the data of response: my_function returns a list of row dicts
    assert ret == [{'name': 'nirma', 'id': 12313}]
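The PropertyMock step in the test above is probably its least obvious part, so here it is on its own. This is a small sketch, assuming a plain MagicMock as a hypothetical cursor. The property is attached to the mock's type because a PropertyMock cannot be attached directly to a mock instance.

```python
from unittest.mock import MagicMock, PropertyMock

cursor = MagicMock()  # hypothetical stand-in for a psycopg2 cursor

# Every read of cursor.description returns the next entry in the list.
description_mock = PropertyMock(side_effect=[
    [['name'], ['phone']],  # columns of the first query
    [['name'], ['id']],     # columns of the second query
])
# Attach to the type: each MagicMock instance has its own class,
# so other mocks are not affected.
type(cursor).description = description_mock

first_columns = [col[0] for col in cursor.description]
second_columns = [col[0] for col in cursor.description]

assert first_columns == ['name', 'phone']
assert second_columns == ['name', 'id']
```

Each attribute read consumes one entry, so code that reads description more than once per query needs a correspondingly longer list.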

Beyond that, if the queries have dynamic parameters, those can also be asserted in the following way.

assert mocked_db.return_value.cursor.return_value.execute.call_args_list[0][0][1] == (parameter_name,)

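So when the first query is executed as cursor.execute(query, (parameter_name,)), call_args_list[0][0][0] gives the query and call_args_list[0][0][1] gives the parameter tuple (parameter_name,). By incrementing the indices in the same way, you can retrieve and assert all the other arguments and queries.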
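The indexing can be checked against a throwaway mock. In this sketch the parameterised query strings and the values 'shanky' and 'soap' are illustrative assumptions; only MagicMock and call from unittest.mock are used.

```python
from unittest.mock import MagicMock, call

cursor = MagicMock()  # hypothetical cursor; execute() only records its calls

customer_query = "SELECT name,phone FROM customer WHERE name=%s"
product_query = "SELECT name,id FROM product WHERE name=%s"

cursor.execute(customer_query, ('shanky',))
cursor.execute(product_query, ('soap',))

calls = cursor.execute.call_args_list

# calls[i][0] is the tuple of positional arguments of the i-th call.
assert calls[0][0][0] == customer_query
assert calls[0][0][1] == ('shanky',)
assert calls[1][0][0] == product_query
assert calls[1][0][1] == ('soap',)

# The same check written as a single comparison against call objects.
assert calls == [
    call(customer_query, ('shanky',)),
    call(product_query, ('soap',)),
]
```

Comparing against a list of call objects tends to read better once there are more than two queries, while the index form is handy for checking a single argument.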

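As I mentioned in an earlier comment, the best way to make unit tests portable is to develop a full mock of the database's behaviour. I have done it for MySQL, but it is almost the same for any database.

First, I like to use a wrapper class around the package I am using. It makes it quick to change the database in one place instead of everywhere in the code.

Here is an example of what I use as a wrapper. The MySQL class built on top of it is the one you will then need to mock: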

# _database.py
# -----------------------------------------------------------------------------
# Database Metaclass
# -----------------------------------------------------------------------------
"""Metaclass for Database implementation.
"""
# -----------------------------------------------------------------------------


import logging


logger = logging.getLogger(__name__)


class Database:
    """Database Metaclass"""

    def __init__(self, connect_func, **kwargs):
        self.connection = connect_func(**kwargs)

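    def execute(self, statement, fetchall=True):
        """Execute a statement.

        Execute the statement passed as argument.

        Args:
            statement (str): SQL Query or Command to execute.
            fetchall (bool): Return all rows if True, only the first row
                otherwise.

        Returns:
            list or tuple: Rows from cursor.fetchall(), or the single row
                from cursor.fetchone().
        """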
        cursor = self.connection.cursor()
        logger.debug(f"Executing: {statement}")
        cursor.execute(statement)
        if fetchall:
            return cursor.fetchall()
        else:
            return cursor.fetchone()

    def __del__(self):
        """Close connection on object deletion."""
        self.connection.close()

And the mysql module:

# mysql.py
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# MySQL Database Class
# -----------------------------------------------------------------------------
"""Class for MySQL Database connection."""
# -----------------------------------------------------------------------------


import logging
import mysql.connector

from . import _database


logger = logging.getLogger(__name__)


class MySQL(_database.Database):
    """MySQL Database Class Wrapper.

    Attributes:
        connection (obj): Object returned from mysql.connector.connect
    """

    def __init__(self, autocommit=True, **kwargs):
        super().__init__(connect_func=mysql.connector.connect, **kwargs)
        self.connection.autocommit = autocommit

Instantiate it like this: db = MySQL(user='...', password='...', ...)

Here is the data file:
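Because the base class receives connect_func as an argument, it can also be exercised without any database driver installed. The sketch below repeats a condensed copy of Database (without logging and __del__) so that it runs on its own. fake_connect and the user='test' argument are hypothetical names introduced for illustration.

```python
from unittest.mock import MagicMock


class Database:
    """Condensed copy of the wrapper above, repeated so the sketch runs alone."""

    def __init__(self, connect_func, **kwargs):
        self.connection = connect_func(**kwargs)

    def execute(self, statement, fetchall=True):
        cursor = self.connection.cursor()
        cursor.execute(statement)
        return cursor.fetchall() if fetchall else cursor.fetchone()


# Hypothetical fake: any callable returning a connection-like object will do.
fake_connect = MagicMock()
fake_cursor = fake_connect.return_value.cursor.return_value
fake_cursor.fetchall.return_value = [('shanky', 123123123)]

db = Database(connect_func=fake_connect, user='test')
statement = "SELECT name,phone FROM customer WHERE name='shanky'"
rows = db.execute(statement)

assert rows == [('shanky', 123123123)]
fake_connect.assert_called_once_with(user='test')
fake_cursor.execute.assert_called_once_with(statement)
```

Injecting the connect function this way keeps the tests from needing to patch a module path at all.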

# database_mock_data.json
{
    "customer": {
        "name": [
            "shanky",
            "nirma"
        ],
        "phone": [
            123123123,
            232342342
        ]
    },
    "product": {
        "name": [
            "shanky",
            "nirma"
        ],
        "id": [
            1,
            2
        ]
    }
}
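Note that the data file stores each table column by column, as parallel lists, so a row is rebuilt by taking the same index from every list. A short sketch of that layout, assuming the customer table above is inlined as a string instead of being read from disk:

```python
import json

# The "customer" table from the data file, inlined for the example.
data = json.loads("""
{
    "customer": {
        "name": ["shanky", "nirma"],
        "phone": [123123123, 232342342]
    }
}
""")

table = data["customer"]
columns = list(table)  # column names, in file order

# zip(*...) transposes the per-column lists into per-row tuples.
rows = list(zip(*(table[column] for column in columns)))

assert columns == ["name", "phone"]
assert rows == [("shanky", 123123123), ("nirma", 232342342)]
```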

mocks.py

# mocks.py
import json
import re
from .mysql import MySQL
_MOCK_DATA_PATH = 'database_mock_data.json'


class MockDatabase(MySQL):
    """
    MySQL wrapper whose connection is replaced by a MockConnection.
    """
    def __init__(self, **kwargs):
        self.connection = MockConnection()


class MockConnection:
    """
    Mock the connection object by returning a mock cursor.
    """
    @staticmethod
    def cursor():
        return MockCursor()

    @staticmethod
    def close():
        """Nothing to release; called by Database.__del__."""


class MockCursor:
    """
    The Mocked Cursor

    A call to execute() will initiate the read on the json data file and will set
    the description object (containing the column names usually).

    You could implement an update function like `_json_sql_update()`
    """
    def __init__(self):
        self.description = []
        self.__result = None

    def execute(self, statement):
        data = _read_json_file(_MOCK_DATA_PATH)
        if statement.upper().startswith('SELECT'):
            self.__result, self.description = _json_sql_select(data, statement)

    def fetchall(self):
        return self.__result

    def fetchone(self):
        return self.__result[0]


def _clean_string(text):
    """Strip quotes and a trailing semicolon (assumed behaviour of this helper)."""
    return text.replace("'", "").replace('"', "").rstrip(";")


def _json_sql_select(data, query):
    """
    Takes a dictionary and returns the values selected by a sql query.
    NOTE: It does not work with other where clauses than '='.
          Also, note that a where statement is expected.
    :param (dict) data: Dictionary with the following structure:
                        {
                            'tablename': {
                                'column_name_1': ['value1', 'value2'],
                                'column_name_2': ['value1', 'value2'],
                                ...
                            },
                            ...
                        }
    :param (str) query: A select sql query as:
                        `select column_name_1 from TABLENAME
                        where column_name_2='value1'`
    :return: List of list of values and header description
    """
    try:
        match = (re.search("select(.*)from(.*)where(.*)", query,
                 re.IGNORECASE | re.DOTALL).groups())
    except AttributeError:
        print("Select Query pattern mismatch... {}".format(query))
        raise

    # Parse values from the select query; names are lower-cased to match
    # the keys used in the json data file
    tablename = match[1].strip().lower()

    columns = [col.strip().lower() for col in match[0].split(",")]
    if columns == ['*']:
        columns = list(data[tablename].keys())

    # Split the where clause on the AND keyword only, not on "and" in values
    where = [cmd.strip()
             for cmd in re.split(r"\s+and\s+", match[2], flags=re.IGNORECASE)]

    # Select values
    selected_values = []
    nb_lines = len(list(data[tablename].values())[0])
    for i in range(nb_lines):
        is_match = True
        for condition in where:
            key_condition, value_condition = [
                part.strip()
                for part in _clean_string(condition).split('=', 1)]
            # Compare as strings, case-insensitively, so numbers work too
            cell = str(data[tablename][key_condition.lower()][i])
            if cell.upper() != value_condition.upper():
                # Row does not satisfy this condition
                is_match = False
        if is_match:
            sub_list = []
            for column in columns:
                sub_list.append(data[tablename][column][i])
            selected_values.append(sub_list)

    # Usual descriptor has nested list
    description = list(zip(columns, ['...'] * len(columns)))

    return selected_values, description


def _read_json_file(file_path):
    with open(file_path, 'r') as f_in:
        data = json.load(f_in)
    return data

Then you write the tests in test_module_yourfunction.py

import pytest

def my_function(db, query):
    # Code goes here
    ...

@pytest.fixture
def db_connection():
    return MockDatabase()


@pytest.mark.parametrize(
    ("query", "expected"),
    [
        ("SELECT name,phone FROM customer WHERE name='shanky'", {'name' : 'nirma', 'id' : 12313}),
        ("<second query goes here>", "<second result goes here>")
    ]
)
def test_my_function(db_connection, query, expected):
    assert my_function(db_connection, query) == expected

Sorry if you can't copy/paste this code and have it work as-is, but you get the idea :) Just trying to help.
