[英]Pyspark unit testing
import unittest
import warnings
from datetime import datetime
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql.types import StringType, StructField, StructType, TimestampType, FloatType
from ohlcv_service.ohlc_gwa import datetime_col
class ReusedPySparkTestCase(unittest.TestCase):
sc_values = {}
@classmethod
def setUpClass(cls):
conf = (SparkConf().setMaster('local[2]')
.setAppName(cls.__name__)
.set('deploy.authenticate.secret', '111111'))
cls.sc = SparkContext(conf=conf)
cls.sc_values[cls.__name__] = cls.sc
cls.spark = (SparkSession.builder
.master('local[2]')
.appName('local-testing-pyspark-context')
.getOrCreate())
@classmethod
def tearDownClass(cls):
print('....calling stop tearDownClass, the content of sc_values=', cls.sc_values, '\n')
for key, sc in cls.sc_values.items():
print('....closing=', key, '\n')
sc.stop()
cls.sc_values.clear()
class TestDateTimeCol(ReusedPySparkTestCase):
def setUp(self):
# Ignore ResourceWarning: unclosed socket.socket!
warnings.simplefilter("ignore", ResourceWarning)
def test_datetime_col(self):
test_data_frame = self.create_data_frame(rows=[['GWA',
'2b600c2a-782f-4ccc-a675-bbbd7d91fde4',
'02fb81fa-91cf-4eab-a07e-0df3c107fbf8',
'2019-06-01T00:00:00.000Z',
0.001243179008694,
0.001243179008694,
0.001243179008694,
0.001243179008694,
0.001243179008694]],
columns=[StructField('indexType', StringType(), False),
StructField('id', StringType(), False),
StructField('indexId', StringType(), False),
StructField('timestamp', StringType(), False),
StructField('price', FloatType(), False),
StructField('open', FloatType(), False),
StructField('high', FloatType(), False),
StructField('low', FloatType(), False),
StructField('close', FloatType(), False)])
expected = self.create_data_frame(rows=[['GWA',
'2b600c2a-782f-4ccc-a675-bbbd7d91fde4',
'02fb81fa-91cf-4eab-a07e-0df3c107fbf8',
'2019-06-01T00:00:00.000Z',
'1559347200',
0.001243179008694,
0.001243179008694,
0.001243179008694,
0.001243179008694,
0.001243179008694]],
columns=[StructField('indexType', StringType(), False),
StructField('id', StringType(), False),
StructField('indexId', StringType(), False),
StructField('timestamp', StringType(), False),
StructField('datetime', TimestampType(), True),
StructField('price', FloatType(), False),
StructField('open', FloatType(), False),
StructField('high', FloatType(), False),
StructField('low', FloatType(), False),
StructField('close', FloatType(), False)])
print(expected)
convert_to_datetime = datetime_col(test_data_frame)
self.assertEqual(expected, convert_to_datetime)
def create_data_frame(self, rows, columns):
rdd = self.sc.parallelize(rows)
df = self.spark.createDataFrame(rdd.collect(), test_schema(columns=columns))
return df
def test_schema(columns):
return StructType(columns)
if __name__ == '__main__':
unittest.main()
错误
TimestampType can not accept object '1559347200' in type <class 'str'>
datetime_col 函数
def datetime_col(df):
return df.select("indexType", "id", "indexId", "timestamp",
(F.col("timestamp").cast(TimestampType)).alias("datetime"),
"price", "open", "high", "low", "close")
datetime col 函数将时间戳从字符串转换为时间戳格式。 这在 EMR-Zeppelin 笔记本中正常工作,但是当我尝试对其进行单元测试时,它会引发上述错误。 我本地的spark和pyspark版本是2.3.1。 如何解决此错误。 当我尝试将 spark df 转换为 pandas df 时,它会将时间戳转换为 +12。
我无法在您的EMR
设置中真正重现您的问题,您没有发布大量信息,无论如何我都无法设置它。 但是您的测试用例存在一些问题,我可以尝试帮助您解决。
您看到的错误消息是因为您无法将string
或更正确的int
直接转换为Timestamp
。 您需要使用to_unixtime
。 像这样的东西工作正常。
expected = self.create_data_frame(rows=[['GWA',
'2b600c2a-782f-4ccc-a675-bbbd7d91fde4',
'02fb81fa-91cf-4eab-a07e-0df3c107fbf8',
'2019-06-01T00:00:00.000Z',
None,
0.001243179008694,
0.001243179008694,
0.001243179008694,
0.001243179008694,
0.001243179008694]],
columns=[StructField('indexType', StringType(), False),
StructField('id', StringType(), False),
StructField('indexId', StringType(), False),
StructField('timestamp', StringType(), False),
StructField('datetime', TimestampType(), True),
StructField('price', FloatType(), False),
StructField('open', FloatType(), False),
StructField('high', FloatType(), False),
StructField('low', FloatType(), False),
StructField('close', FloatType(), False)])
expected = expected.withColumn('datetime', from_unixtime(F.lit(1559347200)).cast(TimestampType()))
第二个问题是您的datetime_col
函数可能在集群中正常工作(正如我所说的我无法真正复制),但它在本地不起作用。 以下方式肯定适用于两者。
def datetime_col(df):
return df.select("indexType", "id", "indexId", "timestamp",
(to_timestamp(F.col("timestamp"),
"yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")).alias("datetime"),
"price", "open", "high", "low", "close")
你需要设置你的时区才能正常工作(@你的setupClass
)。
cls.spark.conf.set("spark.sql.session.timeZone", "UTC")
最后,在您的assert
您必须collect
数据,以便比较数据框的内容。
self.assertEqual(expected.collect(), convert_to_datetime.collect())
希望能帮助到你。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.