Calculate sum and average of a column in a pyspark dataframe and create a new row for the calculated values

I have a PySpark dataframe:

Place   Month       Sector      Estimate    Profit  
USA     1/1/2020    Sector1     5944
Col     1/1/2020    Sector1     398
IND     1/1/2020    Sector1     25
USA     1/1/2020    Sector2                 6.9%
Col     1/1/2020    Sector2                 0.4%
China   1/1/2020    Sector2                 0.0%
Aus     1/1/2020    Sector2                 7.7%

I need to calculate the sum of the Estimate column (including all values) and the average of the Profit column (excluding 0.0%), grouped by Month and Sector.

I also need an extra value in the Place field, Every Places, for the rows that hold these sum and average values. So my desired dataframe should look like this:

Place           Month       Sector      Estimate    Profit  
USA             1/1/2020    Sector1     5944
Col             1/1/2020    Sector1     398
IND             1/1/2020    Sector1     25
USA             1/1/2020    Sector2                 6.9%
Col             1/1/2020    Sector2                 0.4%
China           1/1/2020    Sector2                 0.0%
Aus             1/1/2020    Sector2                 7.7%
Every Places    1/1/2020    Sector1     6367
Every Places    1/1/2020    Sector2                 5%

I tried the code below, but I'm getting this error:

TypeError: Column is not iterable

df1=df.withColumn('Place',lit('Every Places')) \
               .groupBy('Month','Sector') \
               .sum((col('Estimate'))),
               avg(F.col('Profit'))

How can I solve this?

You can first group by Month and Sector to calculate the sum of Estimate and the average of Profit, then union the result with the original dataframe to get the expected output:

import pyspark.sql.functions as F

df = spark.createDataFrame([
    ("USA", "1/1/2020", "Sector1", 5944, None), ("Col", "1/1/2020", "Sector1", 398, None),
    ("IND", "1/1/2020", "Sector1", 25, None), ("USA", "1/1/2020", "Sector2", None, "6.9%"),
    ("Col", "1/1/2020", "Sector2", None, "0.4%"), ("China", "1/1/2020", "Sector2", None, "0.0%"),
    ("Aus", "1/1/2020", "Sector2", None, "7.7%")], ["Place", "Month", "Sector", "Estimate", "Profit"]
)

grouped_df = df.withColumn(
    "Profit",
    # strip the trailing "%" and cast to a number; a null Profit stays null
    F.regexp_extract("Profit", "(.+)%", 1).cast("double")
).groupBy("Month", "Sector").agg(
    F.sum(F.col("Estimate")).alias("Estimate"),
    F.concat(
        # sum of profits divided by the count of non-zero profits,
        # so 0.0% rows are excluded from the average
        F.sum("Profit") / F.sum(F.when(F.col("Profit") != 0.0, 1)),
        F.lit("%")
    ).alias("Profit")
).withColumn(
    "Place",
    F.lit("Every Places")  # label for the summary rows
)

df1 = df.unionByName(grouped_df)

df1.show()
#+------------+--------+-------+--------+------+
#|       Place|   Month| Sector|Estimate|Profit|
#+------------+--------+-------+--------+------+
#|         USA|1/1/2020|Sector1|    5944|  null|
#|         Col|1/1/2020|Sector1|     398|  null|
#|         IND|1/1/2020|Sector1|      25|  null|
#|         USA|1/1/2020|Sector2|    null|  6.9%|
#|         Col|1/1/2020|Sector2|    null|  0.4%|
#|       China|1/1/2020|Sector2|    null|  0.0%|
#|         Aus|1/1/2020|Sector2|    null|  7.7%|
#|Every Places|1/1/2020|Sector2|    null|  5.0%|
#|Every Places|1/1/2020|Sector1|  6367.0|  null|
#+------------+--------+-------+--------+------+
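
To sanity-check the two Every Places rows, the same arithmetic can be reproduced in plain Python without Spark. This is only a sketch for verifying the numbers, not part of the PySpark solution; the summarize helper and the rows list are hypothetical names introduced here, with the data copied from the sample above.

```python
# Rows mirror the sample data: (place, month, sector, estimate, profit)
rows = [
    ("USA", "1/1/2020", "Sector1", 5944, None),
    ("Col", "1/1/2020", "Sector1", 398, None),
    ("IND", "1/1/2020", "Sector1", 25, None),
    ("USA", "1/1/2020", "Sector2", None, "6.9%"),
    ("Col", "1/1/2020", "Sector2", None, "0.4%"),
    ("China", "1/1/2020", "Sector2", None, "0.0%"),
    ("Aus", "1/1/2020", "Sector2", None, "7.7%"),
]


def summarize(rows):
    """Return {(month, sector): (estimate_sum, profit_avg)}.

    A component is None when the group has no values for it.
    """
    groups = {}
    for _place, month, sector, estimate, profit in rows:
        estimates, profits = groups.setdefault((month, sector), ([], []))
        if estimate is not None:
            estimates.append(estimate)
        if profit is not None:
            value = float(profit.rstrip("%"))
            if value != 0.0:  # 0.0% is excluded from the average
                profits.append(value)
    return {
        key: (
            sum(estimates) if estimates else None,
            sum(profits) / len(profits) if profits else None,
        )
        for key, (estimates, profits) in groups.items()
    }


summary = summarize(rows)
```

For the sample data this gives an Estimate sum of 6367 for Sector1 and a Profit average of about 5.0 for Sector2 (15.0 divided by the three non-zero values), matching the summary rows in the output above.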
