
Spark-SQL Window functions on Dataframe - Finding first timestamp in a group

I have the following DataFrame (call it UserData):

uid region  timestamp
a   1   1
a   1   2
a   1   3
a   1   4
a   2   5
a   2   6
a   2   7
a   3   8
a   4   9
a   4   10
a   4   11
a   4   12
a   1   13
a   1   14
a   3   15
a   3   16
a   5   17
a   5   18
a   5   19
a   5   20

This data simply records a user (uid) travelling across different regions (region) at different times (timestamp). For simplicity, timestamp is shown as an int. Note that the real DataFrame will not necessarily be sorted by timestamp, and rows from different users may be interleaved. For simplicity, I have shown the data for a single user only, in monotonically increasing order of timestamp.

My goal is to find how much time user 'a' spent in each region, and in what order. My expected final output looks like this:

uid region  regionTimeStart regionTimeEnd
a   1   1   5
a   2   5   8
a   3   8   9
a   4   9   13
a   1   13  15
a   3   15  17
a   5   17  20

From what I have found, Spark SQL window functions can be used for this. I have tried the following:

val w = Window
  .partitionBy("region")
  .partitionBy("uid")
  .orderBy("timestamp")

val resultDF = UserData.select(
  UserData("uid"), UserData("timestamp"),
  UserData("region"), rank().over(w).as("Rank"))

But from here on, I am not sure how to get the regionTimeStart and regionTimeEnd columns. The regionTimeEnd column is simply the 'lead' of regionTimeStart, except for the last entry in the group, which should end at its own last timestamp.
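To make the 'lead' relationship concrete, here is a minimal sketch using plain Scala collections (no Spark), assuming the start times of consecutive visits are already known and sorted. The object and method names are hypothetical. Each visit ends where the next one starts, and the last visit falls back to the user's final timestamp:

```scala
object LeadSketch {
  // Hypothetical helper: ends(i) = starts(i + 1), i.e. lead(start, 1).
  // The last visit has no successor, so it ends at lastTimestamp.
  def regionEnds(starts: Seq[Int], lastTimestamp: Int): Seq[Int] =
    if (starts.isEmpty) Seq.empty
    else starts.drop(1) :+ lastTimestamp

  def main(args: Array[String]): Unit = {
    // Start times taken from the expected output in the question
    val starts = Seq(1, 5, 8, 9, 13, 15, 17)
    println(regionEnds(starts, 20)) // List(5, 8, 9, 13, 15, 17, 20)
  }
}
```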

I see that aggregate operations have 'first' and 'last' functions, but to use them I would need to group the data by ('uid', 'region'), which destroys the time order of the path traversed. For example, at times 13 and 14 the user has come back to region '1', and I want that visit to be kept separate instead of being merged with the initial visit to region '1' at time 1.
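The problem with grouping by ('uid', 'region') shows up in a small plain Scala sketch (no Spark; all names are hypothetical), which uses min/max as stand-ins for first/last on the question's sample data. Both visits to region 1 collapse into a single span from 1 to 14:

```scala
object GroupByRegionSketch {
  // (uid, region, timestamp) rows from the question, in time order
  val regions = Seq(1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 1, 1, 3, 3, 5, 5, 5, 5)
  val rows: Seq[(String, Int, Int)] =
    regions.zipWithIndex.map { case (region, i) => ("a", region, i + 1) }

  // Rough analogue of groupBy("uid", "region").agg(min("timestamp"), max("timestamp"))
  def spanPerRegion(data: Seq[(String, Int, Int)]): Map[(String, Int), (Int, Int)] =
    data.groupBy { case (uid, region, _) => (uid, region) }.map { case (key, group) =>
      val ts = group.map(_._3)
      key -> ((ts.min, ts.max))
    }

  def main(args: Array[String]): Unit = {
    val spans = spanPerRegion(rows)
    // Both visits to region 1 (times 1-4 and 13-14) end up in one group
    println(spans(("a", 1))) // (1,14)
  }
}
```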

It would be very helpful if anyone could guide me. I am new to Spark, and I understand the Scala Spark API better than the Python/Java Spark APIs.

Window functions are indeed useful here, although your approach can only work if you assume that a user visits a given region only once. Also, the window definition you use is incorrect: each call to partitionBy returns a new WindowSpec that replaces the previous partitioning, so in your code only the last call (partitionBy("uid")) takes effect. If you want to partition by multiple columns, pass them in a single call: .partitionBy("region", "uid").

Let's start by marking consecutive visits to each region:

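import org.apache.spark.sql.functions.{lag, sum, not}
import org.apache.spark.sql.expressions.Window
import spark.implicits._ // $"col" syntax; needs a SparkSession named spark (auto-imported in spark-shell)

// One window per user, ordered by time
val w = Window.partitionBy($"uid").orderBy($"timestamp")

// 1 when the region differs from the previous row's region, else 0.
// <=> is null-safe equality, so the first row of each user (lag is null) gets 1.
val change = (not(lag($"region", 1).over(w) <=> $"region")).cast("int")
// Running sum of change: a new index for each consecutive visit
val ind = sum(change).over(w)

// df is the UserData DataFrame from the question
val dfWithInd = df.withColumn("ind", ind)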
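For readers less familiar with lag and null-safe equality, here is a plain Scala collections sketch of the same two steps (no Spark; all names are hypothetical, and it is an illustration rather than a replacement for the window code). Per user, it sorts by timestamp, flags rows whose region differs from the previous row, and takes a running sum of the flags. Because it groups by uid and sorts first, it also copes with unsorted input and interleaved users, as described in the question.

```scala
object VisitIndexSketch {
  final case class Record(uid: String, region: Int, timestamp: Int)

  // Sample data for user "a" from the question
  val sample: Seq[Record] =
    Seq(1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 1, 1, 3, 3, 5, 5, 5, 5)
      .zipWithIndex.map { case (region, i) => Record("a", region, i + 1) }

  // Analogue of partitionBy(uid).orderBy(timestamp) + change flag + running sum
  def withVisitIndex(rows: Seq[Record]): Seq[(Record, Int)] =
    rows.groupBy(_.uid).toSeq.sortBy(_._1).flatMap { case (_, userRows) =>
      val sorted = userRows.sortBy(_.timestamp)
      // lag(region, 1): None for the first row of each user (SQL NULL)
      val previous: Seq[Option[Int]] = None +: sorted.init.map(r => Some(r.region))
      // not(prev <=> region): None never equals a region, so the first row gets 1
      val changes = sorted.zip(previous).map { case (r, prev) =>
        if (prev.contains(r.region)) 0 else 1
      }
      // sum(change) over the ordered window = running total of the flags
      val ind = changes.scanLeft(0)(_ + _).tail
      sorted.zip(ind)
    }

  def main(args: Array[String]): Unit =
    withVisitIndex(sample.reverse).foreach { case (r, i) =>
      println(s"${r.uid} ${r.region} ${r.timestamp} ind=$i")
    }
}
```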

Next, we simply aggregate over these groups and compute the leads:

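import org.apache.spark.sql.functions.{lead, coalesce, min, max}

// End of a visit = start of the next visit (lead);
// the last visit has no lead, so fall back to its own max timestamp
val regionTimeEnd = coalesce(lead($"timestamp", 1).over(w), $"max_")

val result = dfWithInd
  // one row per consecutive visit
  .groupBy($"uid", $"region", $"ind")
  .agg(min($"timestamp").alias("timestamp"), max($"timestamp").alias("max_"))
  .drop("ind")
  // w (partitioned by uid, ordered by timestamp) now orders visits by their start
  .withColumn("regionTimeEnd", regionTimeEnd)
  .withColumnRenamed("timestamp", "regionTimeStart")
  .drop("max_")

result.show()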

// +---+------+---------------+-------------+
// |uid|region|regionTimeStart|regionTimeEnd|
// +---+------+---------------+-------------+
// |  a|     1|              1|            5|
// |  a|     2|              5|            8|
// |  a|     3|              8|            9|
// |  a|     4|              9|           13|
// |  a|     1|             13|           15|
// |  a|     3|             15|           17|
// |  a|     5|             17|           20|
// +---+------+---------------+-------------+
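The whole pipeline can also be checked against the expected output with a plain Scala collections sketch (no Spark; all names are hypothetical). Instead of computing ind explicitly, it splits each user's time-ordered rows into runs of the same region, which yields the same groups. Each run starts at its first timestamp (min), and ends at the next run's start (lead) or, for the last run, at its own last timestamp (the coalesce with max_).

```scala
object RegionVisitsSketch {
  final case class Record(uid: String, region: Int, timestamp: Int)
  final case class Visit(uid: String, region: Int, regionTimeStart: Int, regionTimeEnd: Int)

  // Sample data for user "a" from the question
  val sample: Seq[Record] =
    Seq(1, 1, 1, 1, 2, 2, 2, 3, 4, 4, 4, 4, 1, 1, 3, 3, 5, 5, 5, 5)
      .zipWithIndex.map { case (region, i) => Record("a", region, i + 1) }

  def visits(rows: Seq[Record]): Seq[Visit] =
    rows.groupBy(_.uid).toSeq.sortBy(_._1).flatMap { case (uid, userRows) =>
      val sorted = userRows.sortBy(_.timestamp).toList
      // Consecutive runs of the same region (the "ind" groups), in time order
      val runs: List[List[Record]] = sorted.foldRight(List.empty[List[Record]]) {
        case (r, (run @ (first :: _)) :: rest) if first.region == r.region => (r :: run) :: rest
        case (r, acc) => List(r) :: acc
      }
      // lead(timestamp, 1) over the runs; None for the last run
      val leads: List[Option[Int]] = runs.drop(1).map(nextRun => Some(nextRun.head.timestamp)) :+ None
      runs.zip(leads).map { case (run, next) =>
        // min(timestamp) is the run's first row; max_ is its last row
        Visit(uid, run.head.region, run.head.timestamp, next.getOrElse(run.last.timestamp))
      }
    }

  def main(args: Array[String]): Unit =
    visits(sample).foreach(println)
}
```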

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM