簡體   English   中英

如何對 PySpark 程序進行單元測試?

[英]How do I unit test PySpark programs?

我當前的 Java/Spark 單元測試方法通過使用“本地”實例化 SparkContext 並使用 JUnit 運行單元測試來工作(在此處詳細說明)。

必須組織代碼以在一個函數中執行 I/O,然后使用多個 RDD 調用另一個函數。

這很好用。 我有一個用 Java + Spark 編寫的經過高度測試的數據轉換。

我可以用 Python 做同樣的事情嗎?

我將如何使用 Python 運行 Spark 單元測試?

我也建議使用 py.test 。 py.test 可以輕松創建可重用的 SparkContext 測試裝置並使用它來編寫簡潔的測試函數。 您還可以專門化設備(例如創建一個 StreamingContext)並在您的測試中使用它們中的一個或多個。

我在 Medium 上寫了一篇關於這個主題的博文:

https://engblog.nextdoor.com/unit-testing-apache-spark-with-py-test-3b8970dc013b

這是帖子中的一個片段:

pytestmark = pytest.mark.usefixtures("spark_context")
def test_do_word_counts(spark_context):
    """ test word couting
    Args:
       spark_context: test fixture SparkContext
    """
    test_input = [
        ' hello spark ',
        ' hello again spark spark'
    ]

    input_rdd = spark_context.parallelize(test_input, 1)
    results = wordcount.do_word_counts(input_rdd)

    expected_results = {'hello':2, 'spark':3, 'again':1}  
    assert results == expected_results

如果您使用的是 Spark 2.x 和SparkSession這里有一個 pytest 解決方案。 我也在導入第三方包。

import logging

import pytest
from pyspark.sql import SparkSession

def quiet_py4j():
    """Suppress spark logging for the test context."""
    logger = logging.getLogger('py4j')
    logger.setLevel(logging.WARN)


@pytest.fixture(scope="session")
def spark_session(request):
    """Fixture for creating a spark context."""

    spark = (SparkSession
             .builder
             .master('local[2]')
             .config('spark.jars.packages', 'com.databricks:spark-avro_2.11:3.0.1')
             .appName('pytest-pyspark-local-testing')
             .enableHiveSupport()
             .getOrCreate())
    request.addfinalizer(lambda: spark.stop())

    quiet_py4j()
    return spark


def test_my_app(spark_session):
   ...

請注意,如果使用 Python 3,我必須將其指定為 PYSPARK_PYTHON 環境變量:

import os
import sys

IS_PY2 = sys.version_info < (3,)

if not IS_PY2:
    os.environ['PYSPARK_PYTHON'] = 'python3'

否則你會得到錯誤:

例外:worker 中的 Python 版本 2.7 與驅動程序 3.5 中的版本不同,PySpark 無法在不同的次要版本下運行。請檢查環境變量 PYSPARK_PYTHON 和 PYSPARK_DRIVER_PYTHON 是否設置正確。

假設您安裝了pyspark ,您可以在unittest使用下面的類進行 unitTest :

import unittest
import pyspark


class PySparkTestCase(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        conf = pyspark.SparkConf().setMaster("local[2]").setAppName("testing")
        cls.sc = pyspark.SparkContext(conf=conf)
        cls.spark = pyspark.SQLContext(cls.sc)

    @classmethod
    def tearDownClass(cls):
        cls.sc.stop()

例子:

class SimpleTestCase(PySparkTestCase):

    def test_with_rdd(self):
        test_input = [
            ' hello spark ',
            ' hello again spark spark'
        ]

        input_rdd = self.sc.parallelize(test_input, 1)

        from operator import add

        results = input_rdd.flatMap(lambda x: x.split()).map(lambda x: (x, 1)).reduceByKey(add).collect()
        self.assertEqual(results, [('hello', 2), ('spark', 3), ('again', 1)])

    def test_with_df(self):
        df = self.spark.createDataFrame(data=[[1, 'a'], [2, 'b']], 
                                        schema=['c1', 'c2'])
        self.assertEqual(df.count(), 2)

請注意,這會為每個類創建一個上下文。 使用setUp而不是setUpClass來獲取每個測試的上下文。 這通常會增加執行測試的大量開銷時間,因為創建新的 Spark 上下文目前很昂貴。

我使用pytest ,它允許測試裝置,因此您可以實例化 pyspark 上下文並將其注入到需要它的所有測試中。 類似的東西

@pytest.fixture(scope="session",
                params=[pytest.mark.spark_local('local'),
                        pytest.mark.spark_yarn('yarn')])
def spark_context(request):
    if request.param == 'local':
        conf = (SparkConf()
                .setMaster("local[2]")
                .setAppName("pytest-pyspark-local-testing")
                )
    elif request.param == 'yarn':
        conf = (SparkConf()
                .setMaster("yarn-client")
                .setAppName("pytest-pyspark-yarn-testing")
                .set("spark.executor.memory", "1g")
                .set("spark.executor.instances", 2)
                )
    request.addfinalizer(lambda: sc.stop())

    sc = SparkContext(conf=conf)
    return sc

def my_test_that_requires_sc(spark_context):
    assert spark_context.textFile('/path/to/a/file').count() == 10

然后,您可以通過調用py.test -m spark_local或在 YARN 中使用py.test -m spark_yarn在本地模式下運行測試。 這對我來說效果很好。

您可以通過在測試套件中的 DataFrame 上運行您的代碼並比較 DataFrame 列相等或兩個整個 DataFrame 的相等來測試 PySpark 代碼。

quinn 項目有幾個例子

為測試套件創建 SparkSession

使用此夾具創建一個 tests/conftest.py 文件,以便您可以輕松訪問測試中的 SparkSession。

import pytest
from pyspark.sql import SparkSession

@pytest.fixture(scope='session')
def spark():
    return SparkSession.builder \
      .master("local") \
      .appName("chispa") \
      .getOrCreate()

列相等示例

假設您想測試以下從字符串中刪除所有非單詞字符的函數。

def remove_non_word_characters(col):
    return F.regexp_replace(col, "[^\\w\\s]+", "")

您可以使用chispa庫中定義的assert_column_equality函數測試此函數。

def test_remove_non_word_characters(spark):
    data = [
        ("jo&&se", "jose"),
        ("**li**", "li"),
        ("#::luisa", "luisa"),
        (None, None)
    ]
    df = spark.createDataFrame(data, ["name", "expected_name"])\
        .withColumn("clean_name", remove_non_word_characters(F.col("name")))
    assert_column_equality(df, "clean_name", "expected_name")

DataFrame 相等示例

有些功能需要通過比較整個 DataFrame 來測試。 這是一個對 DataFrame 中的列進行排序的函數。

def sort_columns(df, sort_order):
    sorted_col_names = None
    if sort_order == "asc":
        sorted_col_names = sorted(df.columns)
    elif sort_order == "desc":
        sorted_col_names = sorted(df.columns, reverse=True)
    else:
        raise ValueError("['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format(
            sort_order=sort_order
        ))
    return df.select(*sorted_col_names)

這是您為此函數編寫的一個測試。

def test_sort_columns_asc(spark):
    source_data = [
        ("jose", "oak", "switch"),
        ("li", "redwood", "xbox"),
        ("luisa", "maple", "ps4"),
    ]
    source_df = spark.createDataFrame(source_data, ["name", "tree", "gaming_system"])

    actual_df = T.sort_columns(source_df, "asc")

    expected_data = [
        ("switch", "jose", "oak"),
        ("xbox", "li", "redwood"),
        ("ps4", "luisa", "maple"),
    ]
    expected_df = spark.createDataFrame(expected_data, ["gaming_system", "name", "tree"])

    assert_df_equality(actual_df, expected_df)

測試輸入/輸出

通常最好從 I/O 函數中抽象出代碼邏輯,這樣它們更容易測試。

假設你有一個這樣的函數。

def your_big_function:
    df = spark.read.parquet("some_directory")
    df2 = df.withColumn(...).transform(function1).transform(function2)
    df2.write.parquet("other directory")

最好像這樣重構代碼:

def all_logic(df):
  return df.withColumn(...).transform(function1).transform(function2)

def your_formerly_big_function:
    df = spark.read.parquet("some_directory")
    df2 = df.transform(all_logic)
    df2.write.parquet("other directory")

像這樣設計您的代碼可以讓您輕松地使用上面提到的列相等或 DataFrame 相等函數測試all_logic函數。 您可以使用your_formerly_big_function來測試your_formerly_big_function 通常最好避免在測試套件中使用 I/O(但有時不可避免)。

pyspark 有 unittest 模塊,可以如下使用

from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase

class MySparkTests(PySparkTestCase):
    def spark_session(self):
        return pyspark.SQLContext(self.sc)

    def createMockDataFrame(self):
         self.spark_session().createDataFrame(
            [
                ("t1", "t2"),
                ("t1", "t2"),
                ("t1", "t2"),
            ],
            ['col1', 'col2']
        )

前段時間我也遇到了同樣的問題,在閱讀了幾篇文章、論壇和一些 StackOverflow 的答案后,我最終為 pytest 編寫了一個小插件: pytest-spark

我已經使用它幾個月了,一般工作流程在 Linux 上看起來不錯:

  1. 安裝 Apache Spark(設置 JVM + 將 Spark 的發行版解壓到某個目錄)
  2. 安裝“pytest”+插件“pytest-spark”
  3. 在您的項目目錄中創建“pytest.ini”並在那里指定 Spark 位置。
  4. 像往常一樣通過 pytest 運行您的測試。
  5. 您可以選擇在測試中使用由插件提供的夾具“spark_context” - 它嘗試最小化輸出中的 Spark 日志。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM