
Python unittest mock pyspark chain

I'd like to write unit tests for some simple methods that contain pyspark code.

def do_stuff(self, df1: DataFrame, df2_path: str, df1_key: str, df2_key: str) -> DataFrame:
    df2 = self.spark.read.format('parquet').load(df2_path)
    return df1.join(df2, [f.col(df1_key) == f.col(df2_key)], 'left')

How can I mock the spark read part? I've tried this:

@patch("class_to_test.SparkSession")
def test_do_stuff(self, mock_spark: MagicMock) -> None:
    spark = MagicMock()
    spark.read.return_value.format.return_value.load.return_value = \
        self.spark.createDataFrame([(1, 2)], ["key2", "c2"])
    mock_spark.return_value = spark

    input_df = self.spark.createDataFrame([(1, 1)], ["key1", "c1"])
    actual_df = ClassToTest().do_stuff(input_df, "df2", "key1", "key2")
    expected_df = self.spark.createDataFrame([(1, 1, 1, 2)], ["key1", "c1", "key2", "c2"])
    assert_pyspark_df_equal(actual_df, expected_df)

But it fails with this error:
py4j.Py4JException: Method join([class java.util.ArrayList, class org.apache.spark.sql.Column, class java.lang.String]) does not exist
It looks like the mocking didn't work as I expected. What should I change so that spark.read.format('parquet').load(...) returns the test DataFrame that I specified?

Edit: full code here

You can do it using PropertyMock. Here is an example:

# test.py
import unittest
from unittest.mock import patch, PropertyMock, Mock

from pyspark.sql import SparkSession, DataFrame, functions as f
from pyspark_test import assert_pyspark_df_equal


class ClassToTest:
    def __init__(self) -> None:
        self._spark = SparkSession.builder.getOrCreate()

    @property
    def spark(self):
        return self._spark

    def do_stuff(self, df1: DataFrame, df2_path: str, df1_key: str, df2_key: str) -> DataFrame:
        df2 = self.spark.read.format('parquet').load(df2_path)
        return df1.join(df2, [f.col(df1_key) == f.col(df2_key)], 'left')


class TestClassToTest(unittest.TestCase):
    def setUp(self) -> None:
        self.spark = SparkSession.builder.getOrCreate()

    def test_do_stuff(self) -> None:
        # let's say ClassToTest().spark.read.format().load() will return a DataFrame
        with patch(
            # change __main__ to your module...
            '__main__.ClassToTest.spark',
            new_callable=PropertyMock,
            return_value=Mock(
                # read property
                read=Mock(
                    # format() method
                    format=Mock(
                        return_value=Mock(
                            # load() method result:
                            load=Mock(return_value=self.spark.createDataFrame([(1, 2)], ['key2', 'c2']))))))
        ):
            input_df = self.spark.createDataFrame([(1, 1)], ['key1', 'c1'])
            df = ClassToTest().do_stuff(input_df, 'df2_path', 'key1', 'key2')
            assert_pyspark_df_equal(
                df,
                self.spark.createDataFrame([(1, 1, 1, 2)], ['key1', 'c1', 'key2', 'c2'])
            )


if __name__ == '__main__':
    unittest.main()

Let's check:

python test.py
# result:
----------------------------------------------------------------------
Ran 1 test in 7.460s

OK
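
As a side note on why the chain in the question probably did not behave as expected: `read` is accessed as an attribute and never called, so setting `read.return_value` configures a call that `do_stuff` never makes, and `load()` hands back a plain MagicMock instead of the test DataFrame. The sketch below shows this with unittest.mock alone, without Spark. `fake_df` and the `Reader` class are hypothetical stand-ins, not part of the original code, and this is only an illustration under those assumptions.

```python
from unittest.mock import MagicMock, PropertyMock, patch

# Hypothetical stand-in for the DataFrame that load() should return.
fake_df = object()

# Chain from the question: `read` is an attribute, never called, so
# configuring read.return_value does not affect spark.read.format(...).
wrong = MagicMock()
wrong.read.return_value.format.return_value.load.return_value = fake_df
wrong_result = wrong.read.format("parquet").load("some/path")

# Adjusted chain: use return_value only after names that are actually called.
right = MagicMock()
right.read.format.return_value.load.return_value = fake_df
right_result = right.read.format("parquet").load("some/path")


class Reader:
    """Hypothetical class that exposes its session through a property."""

    @property
    def spark(self):
        raise RuntimeError("no real session in this sketch")

    def load(self, path):
        return self.spark.read.format("parquet").load(path)


# Patch the property on the class with PropertyMock, as in the answer above.
with patch.object(Reader, "spark", new_callable=PropertyMock, return_value=right):
    patched_result = Reader().load("some/path")
```

Here `wrong_result` is an auto-created MagicMock, while `right_result` and `patched_result` are the `fake_df` object. The nested `Mock(read=Mock(format=...))` construction in the answer and the shorter `read.format.return_value.load.return_value` form configure the same call chain, so either can be passed as the `return_value` of the PropertyMock.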
