
PySpark: count IDs on condition

I have the following dataset and am working with PySpark:

from pyspark.sql import SparkSession

sparkSession = SparkSession.builder.getOrCreate()

df = sparkSession.createDataFrame([(5, 'Samsung', '2018-02-23'),
                                   (8, 'Apple', '2018-02-22'),
                                   (5, 'Sony', '2018-02-21'),
                                   (5, 'Samsung', '2018-02-20'),
                                   (8, 'LG', '2018-02-20')],
                                   ['ID', 'Product', 'Date']
                                  )

+---+-------+----------+
| ID|Product|      Date|
+---+-------+----------+
|  5|Samsung|2018-02-23|
|  8|  Apple|2018-02-22|
|  5|   Sony|2018-02-21|
|  5|Samsung|2018-02-20|
|  8|     LG|2018-02-20|
+---+-------+----------+
# Each ID will ALWAYS appear at least 2 times (do not consider the case of unique IDs in this df)

For each ID, only the product with the highest frequency should have its counter incremented. In case of equal frequency, the most recent date should decide which product receives the +1.

From the sample above, the desired output would be:

+-------+-------+
|Product|Counter|
+-------+-------+
|Samsung|      1|
|  Apple|      1|
|   Sony|      0|
|     LG|      0|
+-------+-------+


# Samsung - 1 (preferred twice by ID=5)
# Apple - 1 (preferred by ID=8 more recently than LG)
# Sony - 0 (because ID=5 preferred Samsung 2 times, and Sony only 1)
# LG - 0 (because ID=8 preferred Apple more recently)
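
To make the rule concrete, here is a plain-Python sketch (illustrative only, not part of the PySpark solution) that reproduces the expected counters from the sample rows, assuming the ISO-format date strings can be compared lexicographically:

from collections import Counter

rows = [(5, 'Samsung', '2018-02-23'),
        (8, 'Apple', '2018-02-22'),
        (5, 'Sony', '2018-02-21'),
        (5, 'Samsung', '2018-02-20'),
        (8, 'LG', '2018-02-20')]

# Start every product at 0, then give +1 to each ID's preferred product.
counters = {product: 0 for _, product, _ in rows}
for i in {id_ for id_, _, _ in rows}:
    mine = [(p, d) for id_, p, d in rows if id_ == i]
    freq = Counter(p for p, _ in mine)
    # Highest frequency wins; ties are broken by the most recent date.
    best = max(freq, key=lambda p: (freq[p], max(d for q, d in mine if q == p)))
    counters[best] += 1

print(counters)  # {'Samsung': 1, 'Apple': 1, 'Sony': 0, 'LG': 0}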

What is the most efficient way with PySpark to achieve this result?

IIUC, you want to pick the most frequent product for each ID, breaking ties using the most recent Date.

So first, we can get the count for each product/ID pair using:

import pyspark.sql.functions as f
from pyspark.sql import Window

df = df.select(
    'ID',
    'Product',
    'Date', 
    f.count('Product').over(Window.partitionBy('ID', 'Product')).alias('count')
)
df.show()
#+---+-------+----------+-----+
#| ID|Product|      Date|count|
#+---+-------+----------+-----+
#|  5|   Sony|2018-02-21|    1|
#|  8|     LG|2018-02-20|    1|
#|  8|  Apple|2018-02-22|    1|
#|  5|Samsung|2018-02-23|    2|
#|  5|Samsung|2018-02-20|    2|
#+---+-------+----------+-----+

Now you can use a Window to rank each product for each ID. We can use pyspark.sql.functions.desc() to sort by count and Date in descending order. If the row_number() is equal to 1, that row is the preferred product for that ID.

w = Window.partitionBy('ID').orderBy(f.desc('count'), f.desc('Date'))
df = df.select(
    'Product',
    (f.row_number().over(w) == 1).cast("int").alias('Counter')
)
df.show()
#+-------+-------+
#|Product|Counter|
#+-------+-------+
#|Samsung|      1|
#|Samsung|      0|
#|   Sony|      0|
#|  Apple|      1|
#|     LG|      0|
#+-------+-------+

Finally, groupBy() the Product and pick the maximum value of Counter:

df.groupBy('Product').agg(f.max('Counter').alias('Counter')).show()
#+-------+-------+
#|Product|Counter|
#+-------+-------+
#|   Sony|      0|
#|Samsung|      1|
#|     LG|      0|
#|  Apple|      1|
#+-------+-------+

Update

Here's a slightly simpler way:

w = Window.partitionBy('ID').orderBy(f.desc('Count'), f.desc('Date'))
df.groupBy('ID', 'Product')\
    .agg(f.max('Date').alias('Date'), f.count('Product').alias('Count'))\
    .select('Product', (f.row_number().over(w) == 1).cast("int").alias('Counter'))\
    .show()
#+-------+-------+
#|Product|Counter|
#+-------+-------+
#|Samsung|      1|
#|   Sony|      0|
#|  Apple|      1|
#|     LG|      0|
#+-------+-------+
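
A note on the design choice: by aggregating to a single row per ID/Product pair first (keeping the count and the latest Date), the window function only has to rank one row per product within each ID, so the final groupBy('Product')/max('Counter') step from the first version is not needed for this sample, where each product appears under a single ID.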
