繁体   English   中英

Pyspark 单元测试

[英]Pyspark unit testing

import unittest
import warnings
from datetime import datetime

from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql.types import StringType, StructField, StructType, TimestampType, FloatType

from ohlcv_service.ohlc_gwa import datetime_col


class ReusedPySparkTestCase(unittest.TestCase):
    sc_values = {}

    @classmethod
    def setUpClass(cls):
        conf = (SparkConf().setMaster('local[2]')
                .setAppName(cls.__name__)
                .set('deploy.authenticate.secret', '111111'))
        cls.sc = SparkContext(conf=conf)
        cls.sc_values[cls.__name__] = cls.sc
        cls.spark = (SparkSession.builder
                     .master('local[2]')
                     .appName('local-testing-pyspark-context')
                     .getOrCreate())

    @classmethod
    def tearDownClass(cls):
        print('....calling stop tearDownClass, the content of sc_values=', cls.sc_values, '\n')
        for key, sc in cls.sc_values.items():
            print('....closing=', key, '\n')
            sc.stop()

        cls.sc_values.clear()


class TestDateTimeCol(ReusedPySparkTestCase):

    def setUp(self):
        # Ignore ResourceWarning: unclosed socket.socket!
        warnings.simplefilter("ignore", ResourceWarning)

    def test_datetime_col(self):
        test_data_frame = self.create_data_frame(rows=[['GWA',
                                                        '2b600c2a-782f-4ccc-a675-bbbd7d91fde4',
                                                        '02fb81fa-91cf-4eab-a07e-0df3c107fbf8',
                                                        '2019-06-01T00:00:00.000Z',
                                                        0.001243179008694,
                                                        0.001243179008694,
                                                        0.001243179008694,
                                                        0.001243179008694,
                                                        0.001243179008694]],
                                                 columns=[StructField('indexType', StringType(), False),
                                                          StructField('id', StringType(), False),
                                                          StructField('indexId', StringType(), False),
                                                          StructField('timestamp', StringType(), False),
                                                          StructField('price', FloatType(), False),
                                                          StructField('open', FloatType(), False),
                                                          StructField('high', FloatType(), False),
                                                          StructField('low', FloatType(), False),
                                                          StructField('close', FloatType(), False)])
        expected = self.create_data_frame(rows=[['GWA',
                                                 '2b600c2a-782f-4ccc-a675-bbbd7d91fde4',
                                                 '02fb81fa-91cf-4eab-a07e-0df3c107fbf8',
                                                 '2019-06-01T00:00:00.000Z',
                                                 '1559347200',
                                                 0.001243179008694,
                                                 0.001243179008694,
                                                 0.001243179008694,
                                                 0.001243179008694,
                                                 0.001243179008694]],
                                          columns=[StructField('indexType', StringType(), False),
                                                   StructField('id', StringType(), False),
                                                   StructField('indexId', StringType(), False),
                                                   StructField('timestamp', StringType(), False),
                                                   StructField('datetime', TimestampType(), True),
                                                   StructField('price', FloatType(), False),
                                                   StructField('open', FloatType(), False),
                                                   StructField('high', FloatType(), False),
                                                   StructField('low', FloatType(), False),
                                                   StructField('close', FloatType(), False)])
        print(expected)
        convert_to_datetime = datetime_col(test_data_frame)
        self.assertEqual(expected, convert_to_datetime)

    def create_data_frame(self, rows, columns):
        rdd = self.sc.parallelize(rows)
        df = self.spark.createDataFrame(rdd.collect(), test_schema(columns=columns))
        return df


def test_schema(columns):
    return StructType(columns)


if __name__ == '__main__':
    unittest.main()

错误

TimestampType can not accept object '1559347200' in type <class 'str'>

datetime_col 函数

def datetime_col(df):
      return df.select("indexType", "id", "indexId", "timestamp",
                     (F.col("timestamp").cast(TimestampType)).alias("datetime"),
                     "price", "open", "high", "low", "close")

datetime col 函数将时间戳从字符串转换为时间戳格式。 这在 EMR-Zeppelin 笔记本中正常工作,但是当我尝试对其进行单元测试时,它会引发上述错误。 我本地的spark和pyspark版本是2.3.1。 如何解决此错误。 当我尝试将 spark df 转换为 pandas df 时,它会将时间戳转换为 +12。

我无法在您的EMR设置中真正重现您的问题,您没有发布大量信息,无论如何我都无法设置它。 但是您的测试用例存在一些问题,我可以尝试帮助您解决。

您看到的错误消息是因为您无法将string或更正确的int直接转换为Timestamp 您需要使用to_unixtime 像这样的东西工作正常。

expected = self.create_data_frame(rows=[['GWA',
                                         '2b600c2a-782f-4ccc-a675-bbbd7d91fde4',
                                         '02fb81fa-91cf-4eab-a07e-0df3c107fbf8',
                                         '2019-06-01T00:00:00.000Z',
                                         None,
                                         0.001243179008694,
                                         0.001243179008694,
                                         0.001243179008694,
                                         0.001243179008694,
                                         0.001243179008694]],
                                  columns=[StructField('indexType', StringType(), False),
                                           StructField('id', StringType(), False),
                                           StructField('indexId', StringType(), False),
                                           StructField('timestamp', StringType(), False),
                                           StructField('datetime', TimestampType(), True),
                                           StructField('price', FloatType(), False),
                                           StructField('open', FloatType(), False),
                                           StructField('high', FloatType(), False),
                                           StructField('low', FloatType(), False),
                                           StructField('close', FloatType(), False)])
expected = expected.withColumn('datetime', from_unixtime(F.lit(1559347200)).cast(TimestampType()))

第二个问题是您的datetime_col函数可能在集群中正常工作(正如我所说的我无法真正复制),但它在本地不起作用。 以下方式肯定适用于两者。

def datetime_col(df):
    return df.select("indexType", "id", "indexId", "timestamp",
                     (to_timestamp(F.col("timestamp"), 
                                   "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")).alias("datetime"),
                     "price", "open", "high", "low", "close")

你需要设置你的时区才能正常工作(@你的setupClass )。

cls.spark.conf.set("spark.sql.session.timeZone", "UTC")

最后,在您的assert您必须collect数据,以便比较数据框的内容。

self.assertEqual(expected.collect(), convert_to_datetime.collect())

希望能帮助到你。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM