
Spark dataFrame taking too long to display after updating its columns

I have a dataFrame of approx. 4 million rows and 35 columns as input.

All I do to this dataFrame is the following:

  • For each column in a given list, I compute its sum per group (for a given list of group features) and join the result to my input dataFrame as a new column.
  • I drop each new sum column right after joining it to the dataFrame.

So, in theory, I end up with the same dataFrame I started with.

However, once my list of columns grows beyond 6, the output dataFrame becomes practically unusable: even a simple display takes 10 minutes.

Here is an example of my code (df is my input dataFrame):

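  import pyspark.sql.functions as f

  for c in list_columns:
    # Sum c per group, join the sums back onto df, then drop the new column.
    df = df.join(df.groupby(list_group_features).agg(f.sum(c).alias('sum_' + c)), list_group_features)
    df = df.drop('sum_' + c)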

This happens due to the inner workings of Spark and its lazy evaluation.

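When you call groupby, join or agg, Spark does not process any data; it only appends these operations to the execution plan of the df object. So even though nothing is executed yet, you are building a large execution plan that is stored internally in the Spark DataFrame object. In your loop this plan grows especially fast: each iteration joins df with an aggregation of df itself, so the previous plan appears twice in the new one and the plan roughly doubles in size with every column.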
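
To see why a handful of columns is enough to cause trouble, here is a toy model in plain Python (not Spark code; `PlanNode` and `loop_plan` are hypothetical names). It assumes each loop iteration adds a Join, an Aggregate and a Project (for the drop) on top of two references to the previous plan, and counts the resulting operators the way a tree-based optimizer would traverse them.

```python
from dataclasses import dataclass, field


@dataclass
class PlanNode:
    """Toy stand-in for a logical plan operator (NOT Spark's real classes)."""
    op: str
    children: list = field(default_factory=list)

    def size(self) -> int:
        # Count this node plus every node beneath it, visiting shared
        # subplans once per reference, as a tree traversal would.
        return 1 + sum(child.size() for child in self.children)


def loop_plan(num_columns: int) -> PlanNode:
    """Mimic the question's loop: join df with an aggregate of df, then drop."""
    plan = PlanNode("Scan")
    for _ in range(num_columns):
        aggregate = PlanNode("Aggregate", [plan])
        joined = PlanNode("Join", [plan, aggregate])  # previous plan appears twice
        plan = PlanNode("Project", [joined])  # models the drop() of the sum column
    return plan


for n in range(9):
    print(n, loop_plan(n).size())
```

Under these assumptions the plan has 2^(n+2) - 3 nodes after n columns, e.g. 253 nodes for 6 columns and 1021 for 8, so the work the optimizer has to do grows exponentially rather than linearly with the number of columns.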

Only when you call an action (show, count, write, etc.) does Spark optimize the plan and execute it. If the plan is very large, the optimization step alone can take a long time. Also remember that plan optimization happens on the driver, not on the executors, so a busy or overloaded driver delays the optimization step as well.
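
Lazy evaluation itself can be illustrated with a small toy class in plain Python. `LazyFrame` and its methods are hypothetical and only loosely mirror Spark's split between transformations and actions: transformations merely append a step to a recorded plan, and nothing runs until the action `collect()` walks that plan.

```python
class LazyFrame:
    """Toy lazily evaluated frame; a hypothetical class, not the PySpark API."""

    def __init__(self, rows, plan=()):
        self._rows = list(rows)
        self._plan = plan  # recorded transformations; nothing executed yet

    def map(self, fn):
        # Transformation: return a new frame whose plan has one more step.
        return LazyFrame(self._rows, self._plan + (("map", fn),))

    def filter(self, pred):
        # Transformation: likewise only records the step.
        return LazyFrame(self._rows, self._plan + (("filter", pred),))

    def plan_length(self):
        return len(self._plan)

    def collect(self):
        # Action: only now is the recorded plan executed, step by step.
        out = list(self._rows)
        for kind, fn in self._plan:
            if kind == "map":
                out = [fn(r) for r in out]
            else:
                out = [r for r in out if fn(r)]
        return out


calls = []


def double(x):
    calls.append(x)  # record every time the function really runs
    return 2 * x


lf = LazyFrame([1, 2, 3]).map(double).filter(lambda x: x > 2)
print(lf.plan_length(), len(calls))  # 2 0: plan built, no work done yet
print(lf.collect())                  # [4, 6]
print(len(calls))                    # 3
```

In Spark the recorded plan is additionally optimized before execution, which is exactly the step that becomes slow when the plan is huge.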

It is worth remembering that joins are expensive operations in Spark, both to optimize and to execute. When you only need per-group aggregates on a single DataFrame, avoid the self-join and use window functions instead. Joins should only be used when you are combining different DataFrames from different sources (different tables).
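
A window can replace the self-join because both give every row the total of its own group. The plain-Python sketch below (hypothetical helper names and sample data, not Spark code) computes the sum column both ways and checks that the results agree. Rows are sorted before comparing because, as in Spark, neither approach guarantees output order.

```python
from collections import defaultdict

# Hypothetical sample data: group key "g", value column "x".
rows = [
    {"g": "a", "x": 1},
    {"g": "a", "x": 2},
    {"g": "b", "x": 10},
]


def groupby_then_join(rows, key, col):
    # Step 1: one aggregated value per group, like groupby(key).agg(sum(col)).
    totals = defaultdict(int)
    for r in rows:
        totals[r[key]] += r[col]
    # Step 2: join the aggregate back onto every original row.
    return [{**r, "sum_" + col: totals[r[key]]} for r in rows]


def window_sum(rows, key, col):
    # Window semantics: split rows into partitions by key, then give every
    # row of a partition that partition's total; no separate join step.
    partitions = defaultdict(list)
    for r in rows:
        partitions[r[key]].append(r)
    out = []
    for part in partitions.values():
        total = sum(r[col] for r in part)
        out.extend({**r, "sum_" + col: total} for r in part)
    return out


def canonical(rs):
    # Sort rows so the comparison ignores output order.
    return sorted(rs, key=lambda r: (r["g"], r["x"]))


print(canonical(window_sum(rows, "g", "x")) == canonical(groupby_then_join(rows, "g", "x")))  # True
```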

A way to optimize your code would be:

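import pyspark.sql.functions as f
from pyspark.sql import Window

# Partition by the group keys instead of grouping and joining back.
w = Window.partitionBy(list_group_features)
# Call over() on the aggregate first, then alias the resulting window column.
agg_sum_exprs = [f.sum(f.col(c)).over(w).alias("sum_" + c) for c in list_columns]
# One select adds every sum column in a single projection.
res_df = df.select(df.columns + agg_sum_exprs)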

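Because every window expression shares the same partitioning, Spark can compute all the sums after a single shuffle by the group keys, and the plan grows by one expression per column instead of one join per column. This should be scalable and fast even for large list_group_features and list_columns lists.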
