
Interpolation in PySpark throws java.lang.IllegalArgumentException

I don't know how to interpolate in PySpark when the DataFrame contains many columns. Let me explain.

from pyspark.sql.functions import to_timestamp

df = spark.createDataFrame([
    ("John",  "A", "2018-02-01 03:00:00", 60),  
    ("John",  "A", "2018-02-01 03:03:00", 66),  
    ("John",  "A", "2018-02-01 03:05:00", 70),  
    ("John",  "A", "2018-02-01 03:08:00", 76),  
    ("Mo",    "A", "2017-06-04 01:05:00", 10),  
    ("Mo",    "A", "2017-06-04 01:07:00", 20),  
    ("Mo",    "B", "2017-06-04 01:10:00", 35),  
    ("Mo",    "B", "2017-06-04 01:11:00", 40),
], ("webID", "aType", "timestamp", "counts")).withColumn(
  "timestamp", to_timestamp("timestamp")
)

I need to group by webID and interpolate the counts values at a 1-minute interval.
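To illustrate, here is the result I expect for John's group. The sketch below is just plain pandas resample + interpolate on that one group (the same operations my UDF below tries to apply per group):

import pandas as pd

john = pd.DataFrame(
    {"counts": [60, 66, 70, 76]},
    index=pd.to_datetime([
        "2018-02-01 03:00:00", "2018-02-01 03:03:00",
        "2018-02-01 03:05:00", "2018-02-01 03:08:00",
    ]),
)
print(john.resample("60S").interpolate())
#                      counts
# 2018-02-01 03:00:00    60.0
# 2018-02-01 03:01:00    62.0
# 2018-02-01 03:02:00    64.0
# 2018-02-01 03:03:00    66.0
# 2018-02-01 03:04:00    68.0
# 2018-02-01 03:05:00    70.0
# 2018-02-01 03:06:00    72.0
# 2018-02-01 03:07:00    74.0
# 2018-02-01 03:08:00    76.0

However, when I apply the code shown below,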

from operator import attrgetter
from pyspark.sql.types import StructType
from pyspark.sql.functions import pandas_udf, PandasUDFType

def resample(schema, freq, timestamp_col = "timestamp",**kwargs):
    @pandas_udf(
        StructType(sorted(schema, key=attrgetter("name"))), 
        PandasUDFType.GROUPED_MAP)
    def _(pdf):
        pdf.set_index(timestamp_col, inplace=True)
        pdf = pdf.resample(freq).interpolate()
        pdf.ffill(inplace=True)
        pdf.reset_index(drop=False, inplace=True)
        pdf.sort_index(axis=1, inplace=True)
        return pdf
    return _


df.groupBy("webID").apply(resample(df.schema, "60S")).show()

Error:

py4j.protocol.Py4JJavaError: An error occurred while calling o371.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 77 in stage 31.0 failed 4 times, most recent failure: Lost task 77.3 in stage 31.0 (TID 812, 27faa516aadb4c40b7d7586d7493143c0021c825663, executor 2): java.lang.IllegalArgumentException
    at java.nio.ByteBuffer.allocate(ByteBuffer.java:334)

Set the environment variable ARROW_PRE_0_15_IPC_FORMAT=1. This makes PyArrow >= 0.15.0 fall back to the legacy Arrow IPC format that Spark 2.3.x/2.4.x expects, which is the incompatibility behind this java.lang.IllegalArgumentException:

https://spark.apache.org/docs/3.0.0-preview/sql-pyspark-pandas-with-arrow.html#compatibiliy-setting-for-pyarrow--0150-and-spark-23x-24x

def resample(schema, freq, timestamp_col = "timestamp",**kwargs):
    @pandas_udf(
        StructType(sorted(schema, key=attrgetter("name"))), 
        PandasUDFType.GROUPED_MAP)
    def _(pdf):
        import os                                      # add this line
        os.environ['ARROW_PRE_0_15_IPC_FORMAT']='1'    # add this line
        pdf.set_index(timestamp_col, inplace=True)
        pdf = pdf.resample(freq).interpolate()
        pdf.ffill(inplace=True)
        pdf.reset_index(drop=False, inplace=True)
        pdf.sort_index(axis=1, inplace=True)
        return pdf
    return _
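A possible alternative to patching every UDF (this is a sketch of my own, not taken verbatim from the linked docs) is to set the variable once for the driver process and for the executors when the session is created, e.g. via spark.executorEnv:

import os
from pyspark.sql import SparkSession

# Driver-side Python process
os.environ["ARROW_PRE_0_15_IPC_FORMAT"] = "1"

# Executor-side processes: spark.executorEnv.* forwards the variable to the executors
spark = (
    SparkSession.builder
    .config("spark.executorEnv.ARROW_PRE_0_15_IPC_FORMAT", "1")
    .getOrCreate()
)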

You can also do the resampling in PySpark itself, without a pandas UDF (or Python UDF). The solution below performs better on big datasets than the pandas UDF method and also avoids the error you are getting:

import pyspark.sql.functions as F

df = spark.createDataFrame([
    ("John",  "A", "2018-02-01 03:00:00", 60),  
    ("John",  "A", "2018-02-01 03:03:00", 66),  
    ("John",  "A", "2018-02-01 03:05:00", 70),  
    ("John",  "A", "2018-02-01 03:08:00", 76),  
    ("Mo",    "A", "2017-06-04 01:05:00", 10),  
    ("Mo",    "A", "2017-06-04 01:07:00", 20),  
    ("Mo",    "B", "2017-06-04 01:10:00", 35),  
    ("Mo",    "B", "2017-06-04 01:11:00", 40),
], ("webID", "aType", "timestamp", "counts")).withColumn(
  "timestamp", F.to_timestamp("timestamp")
)
resample_interval = 1*60  # Resample interval size in seconds

df_interpolated = (
  df
  # Get timestamp and Counts of previous measurement via window function
  .selectExpr(
    "webID",
    "aType",
    "LAG(Timestamp) OVER (PARTITION BY webID ORDER BY Timestamp ASC) as PreviousTimestamp",
    "Timestamp as NextTimestamp",
    "LAG(Counts) OVER (PARTITION BY webID ORDER BY Timestamp ASC) as PreviousCounts",
    "Counts as NextCounts",
  )
  # To determine the resampled range, round the start timestamp up and the end timestamp down to the nearest interval boundary
  .withColumn("PreviousTimestampRoundUp", F.expr(f"to_timestamp(ceil(unix_timestamp(PreviousTimestamp)/{resample_interval})*{resample_interval})"))
  .withColumn("NextTimestampRoundDown", F.expr(f"to_timestamp(floor(unix_timestamp(NextTimestamp)/{resample_interval})*{resample_interval})"))
  # Make sure we don't get any negative intervals (whole interval is within resample interval)
  .filter("PreviousTimestampRoundUp<=NextTimestampRoundDown")
  # Create resampled time axis by creating all "interval" timestamps between previous and next timestamp
  .withColumn("Timestamp", F.expr(f"explode(sequence(PreviousTimestampRoundUp, NextTimestampRoundDown, interval {resample_interval} second)) as Timestamp"))
  # Sequence has inclusive boundaries for both start and stop. Filter out duplicate Counts if original timestamp is exactly a boundary.
  .filter("Timestamp<NextTimestamp")
  # Interpolate Counts between previous and next
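  #   Worked example: for John, between 03:05:00 (counts 70) and 03:08:00 (counts 76),
  #   the resampled point at 03:06:00 gets 70 + (60/180)*(76-70) = 72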
  .selectExpr(
    "webID",
    "aType",
    "Timestamp", 
    """(unix_timestamp(Timestamp)-unix_timestamp(PreviousTimestamp))
        /(unix_timestamp(NextTimestamp)-unix_timestamp(PreviousTimestamp))
        *(NextCounts-PreviousCounts) 
        +PreviousCounts
        as Counts"""
  )
)
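To inspect the result, something like this should do (assuming the DataFrame above was created in the same session):

# Show the interpolated series per webID, ordered in time
df_interpolated.orderBy("webID", "Timestamp").show(truncate=False)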

I recently wrote a blog post that explains the reasoning behind this method and compares its performance with the pandas UDF approach you are using: https://medium.com/delaware-pro/interpolate-big-data-time-series-in-native-pyspark-d270d4b592a1
