
Dataframe transformation

I am new to Spark and Scala. I have a dataframe that holds a huge amount of data, with the following schema.

Let's call this dataframe empDF:

       id   name      emp_id  code  date
       1    Andrew    D01     C101  2012-06-14
       2    James     D02     C101  2013-02-26
       3    James     D02     C102  2013-12-29
       4    James     D02     C101  2010-09-27
       5    Andrew    D01     C101  2013-10-12
       6    Andrew    D01     C102  2011-10-13

I read this data from the database as a DataFrame (i.e. a Dataset[Row]). Now I have to perform the following steps:

For each row with code C101 a level of at least 1 must be set, and for every other code the level should be 0. If there is no previous record for the employee, the level is set to 1. If there is a previous record that is two or more years older than that record, the level is set to 2. After this step the dataframe should look like this:

       id   name      emp_id  code  date         level
       1    Andrew    D01     C101  2012-06-14     2
       2    James     D02     C101  2013-02-26     2
       3    James     D02     C102  2012-12-29     0
       4    James     D02     C101  2010-09-27     1
       5    Andrew    D01     C101  2009-10-12     1
       6    Andrew    D01     C102  2010-10-13     0

The first and second rows have level 2 because each employee has an older record and the date difference between the rows is more than two years. The rows with level 1 have no record with an earlier date, and the rows with level 0 are those whose code is not C101.

Now, for the rows with level 2, we have to check whether code C102 was applied to the same employee within the last year; if so, set the level to 3, otherwise leave it unchanged. In the final result dataframe, all rows with a code other than C101 should be dropped.

After the above two steps, the resulting dataframe should look like this:

     id name    emp_id  code  date       level
     1  Andrew  D01     C101  2012-06-14   2
     2  James   D02     C101  2013-02-26   3
     4  James   D02     C101  2010-09-27   1
     5  Andrew  D01     C101  2009-10-12   1

Notice that the first row keeps level 2 because this employee has no C102 record within the last year, while the second row does have one. How can I do this in Scala using the DataFrame API and functions like map, flatMap, reduce, etc.?

You can use window functions:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Previous C101 date for the same employee, in date order
// (partitioning by emp_id would be safer if names can repeat)
val window = Window.partitionBy("name").orderBy("date")
val lagCol = lag(col("date"), 1).over(window)
// Days between this record and the previous one
val diff = datediff(col("date"), col("previousDate"))
// Level 1 when there is no previous record, or it is less than
// two years (~730 days) older; otherwise level 2
val level = when(col("previousDate").isNull || diff < 730, 1).otherwise(2)

val newDF = empDF
  .where(col("code") === "C101")      // only C101 rows get a level here
  .withColumn("previousDate", lagCol)
  .withColumn("level", level)
  .drop("previousDate")

newDF.orderBy("id").show

+---+------+------+----+----------+-----+
| id|  name|emp_id|code|      date|level|
+---+------+------+----+----------+-----+
|  1|Andrew|   D01|C101|2012-06-14|    1|
|  2| James|   D02|C101|2013-02-26|    2|
|  4| James|   D02|C101|2010-09-27|    1|
|  5|Andrew|   D01|C101|2013-10-12|    1|
+---+------+------+----+----------+-----+

Note that the first row gets level 1 here because, with the dates in the first input table, Andrew has no C101 record two or more years older; with the 2009-10-12 date shown in the question's second table it would get level 2.
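Neither snippet above handles the second step, promoting level-2 rows to level 3 when the same employee has a C102 record within the last year. Below is a minimal, self-contained sketch of that step using a left join. It assumes `leveled` (hardcoded here) is the step-1 result, reads "within last year" as "within the 365 days before the C101 row's date", and uses the dates from the question's second table; all of these are assumptions, not part of the original answers.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object Step2Sketch {
  def run(): Map[Int, Int] = {
    val spark = SparkSession.builder.master("local[*]").appName("step2").getOrCreate()
    import spark.implicits._

    // Step-1 result: C101 rows with their computed level (hardcoded sketch)
    val leveled = Seq(
      (1, "Andrew", "D01", "C101", "2012-06-14", 2),
      (2, "James",  "D02", "C101", "2013-02-26", 2),
      (4, "James",  "D02", "C101", "2010-09-27", 1),
      (5, "Andrew", "D01", "C101", "2009-10-12", 1)
    ).toDF("id", "name", "emp_id", "code", "date", "level")

    // C102 rows taken from the original dataframe
    val c102 = Seq(
      ("D02", "2012-12-29"),
      ("D01", "2010-10-13")
    ).toDF("emp_id2", "c102_date")

    // Promote level 2 to 3 when a C102 record for the same employee
    // falls within the 365 days before the C101 row's date
    val result = leveled
      .join(c102,
        col("emp_id") === col("emp_id2") &&
          datediff(col("date"), col("c102_date")).between(0, 365),
        "left_outer")
      .withColumn("level",
        when(col("level") === 2 && col("c102_date").isNotNull, 3)
          .otherwise(col("level")))
      .drop("emp_id2", "c102_date")
      .dropDuplicates("id")

    val out = result.select("id", "level").collect()
      .map(r => r.getInt(0) -> r.getInt(1)).toMap
    spark.stop()
    out
  }
}
```

The left join can duplicate a C101 row when several C102 records fall inside the window, hence the `dropDuplicates("id")`. For James (row 2) the C102 record of 2012-12-29 is 59 days before 2013-02-26, so his level becomes 3; Andrew's C102 record is more than a year older, so row 1 keeps level 2.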
Another way, building the sample data explicitly and deriving the level from each employee's earliest C101 date:

// Input data (a subset of the question's rows)
import org.apache.spark.sql._
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import scala.collection.JavaConverters._
import java.time.LocalDate

val simpleSchema = StructType(
    StructField("id", IntegerType) ::
    StructField("name", StringType) ::
    StructField("emp_id", StringType) ::
    StructField("code", StringType) ::
    StructField("date", DateType) :: Nil)

val data = List(
    Row(1, "Andrew", "D01", "C101", java.sql.Date.valueOf(LocalDate.of(2012, 6, 14))),
    Row(2, "James", "D02", "C101", java.sql.Date.valueOf(LocalDate.of(2013, 2, 26))),
    Row(3, "James", "D02", "C102", java.sql.Date.valueOf(LocalDate.of(2013, 12, 29))),
    Row(4, "James", "D02", "C101", java.sql.Date.valueOf(LocalDate.of(2010, 9, 27)))
)

val df = spark.createDataFrame(data.asJava, simpleSchema)
df.show()

// Filter to C101 and derive the level from the distance to the
// employee's earliest C101 date
val df2 = df
    .filter(col("code") === "C101")
    .withColumn("level",
        when(datediff(col("date"),
            min(col("date")).over(Window.partitionBy("emp_id"))) >= 365 * 2, 2)
        .otherwise(1))
df2.show()
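Both steps can also be expressed with windows alone, avoiding a join: a range frame over epoch seconds flags whether a C102 record exists in the preceding 365 days, and a lag window computes the step-1 level. This is a sketch under the same assumptions as the question's second table (365/730-day cutoffs, the 2009/2012 dates); `lastYear`, `byDate`, and `hasRecentC102` are names introduced here.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object RangeWindowSketch {
  def run(): Map[Int, Int] = {
    val spark = SparkSession.builder.master("local[*]").appName("rangeWindow").getOrCreate()
    import spark.implicits._

    val df = Seq(
      (1, "Andrew", "D01", "C101", "2012-06-14"),
      (2, "James",  "D02", "C101", "2013-02-26"),
      (3, "James",  "D02", "C102", "2012-12-29"),
      (4, "James",  "D02", "C101", "2010-09-27"),
      (5, "Andrew", "D01", "C101", "2009-10-12"),
      (6, "Andrew", "D01", "C102", "2010-10-13")
    ).toDF("id", "name", "emp_id", "code", "date")
      .withColumn("date", to_date(col("date")))

    // Range frame: rows of the same employee dated within the
    // 365 days up to and including the current row's date
    val lastYear = Window.partitionBy("emp_id")
      .orderBy(col("date").cast("timestamp").cast("long"))
      .rangeBetween(-365L * 24 * 60 * 60, 0)
    val hasRecentC102 =
      max(when(col("code") === "C102", 1).otherwise(0)).over(lastYear)

    // Step-1 window: previous C101 date per employee
    val byDate = Window.partitionBy("emp_id").orderBy("date")

    val result = df
      .withColumn("hasC102", hasRecentC102)   // computed before filtering
      .filter(col("code") === "C101")
      .withColumn("prev", lag(col("date"), 1).over(byDate))
      .withColumn("level",
        when(col("prev").isNull || datediff(col("date"), col("prev")) < 730, 1)
          .otherwise(2))
      .withColumn("level",
        when(col("level") === 2 && col("hasC102") === 1, 3)
          .otherwise(col("level")))
      .drop("prev", "hasC102")

    val out = result.select("id", "level").collect()
      .map(r => r.getInt(0) -> r.getInt(1)).toMap
    spark.stop()
    out
  }
}
```

With these dates the result matches the question's final table: rows 1, 2, 4, 5 get levels 2, 3, 1, 1. The key point is that `hasC102` must be computed before filtering to C101, otherwise the C102 rows are gone before the window sees them.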
