Can we dynamically retrieve the previous row's value of an updating column in a PySpark DataFrame?

So here is my exact problem statement. I have the DataFrame shown below.

+--------+-------+
|  START |  END  |
+--------+-------+
|   1    |   5   |
|   3    |   6   |
|   7    |   10  |
|   13   |   17  |
|   15   |   20  |
+--------+-------+

Imagine each row represents a line starting from START and ending at END on the X-axis. When we place them according to the given data, we don't want the lines to intersect, so we stack them instead.

So the first line remains the same, i.e. (1, 5).

As the second line intersects the first, we need to change its START and END values, so (3, 6) becomes (5, 8). (We cannot change the length of a line when we stack.)

The third line (7, 10) becomes (8, 11), as it intersects the previous line, which is now (5, 8).

As the fourth line doesn't intersect the updated third line, we don't change its values, so it remains (13, 17).

And the last one, (15, 20), intersects the fourth line (13, 17), so it becomes (17, 22).

So my final DataFrame should be:

+--------+-------+
|  START |  END  |
+--------+-------+
|   1    |   5   |
|   5    |   8   |
|   8    |   11  |
|   13   |   17  |
|   17   |   22  |
+--------+-------+

You can assume that the initial DataFrame is sorted by its START column.

Now, this is an easy problem when we use loops, but I wanted to do it in PySpark without any loops. I am new to PySpark and couldn't find a good function to achieve this without loops.
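
For reference, this is roughly what the loop version looks like (just a plain-Python sketch over a list of (START, END) tuples, assumed already sorted by START):

# Minimal loop-based sketch of the intended stacking.
def stack_lines(lines):
    result = []
    prev_end = None
    for start, end in lines:
        length = end - start
        if prev_end is not None and start < prev_end:
            # Current line intersects the previous (already stacked) line,
            # so shift it to begin where the previous line ends.
            start = prev_end
            end = start + length
        result.append((start, end))
        prev_end = end
    return result

print(stack_lines([(1, 5), (3, 6), (7, 10), (13, 17), (15, 20)]))
# [(1, 5), (5, 8), (8, 11), (13, 17), (17, 22)]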

Now, coming to my title: if I were able to retrieve the previous row's END value dynamically (as it changes) and compare it with the current START value, I could solve this problem. But I couldn't find anything that does that.

Here is my attempt to solve this problem without loops:

from pyspark import SparkContext 
from pyspark.sql import SQLContext
from pyspark.sql.functions import lag, col
from pyspark.sql.window import Window
from pyspark.sql import functions as F

sc = SparkContext('local', 'Example_2')
sqlcontext = SQLContext(sc)

df = sqlcontext.createDataFrame([(1, 5), (3, 6), (7, 10), (13, 17), (15, 20)], ['START', 'END'])

w = Window().orderBy('START').rangeBetween(Window.unboundedPreceding, -1)

# Updating 'END' column first
df = df.withColumn('END', F.when(
                F.last('END', True).over(w) > col('START'),
                col('END') + (F.last('END').over(w) - col('START'))
            ).otherwise(col('END')))

# Updating 'START' column
df = df.withColumn('START', F.when(
                    F.last('END', True).over(w) > col('START'),
                    F.last('END', True).over(w)
                ).otherwise(col('START')))

As F.last('END') doesn't give the updated END value, the above code returns the following, in which the third row is wrong:

+-----+---+
|START|END|
+-----+---+
|    1|  5|
|    5|  8|
|    8| 10|
|   13| 17|
|   17| 22|
+-----+---+
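
This can be checked directly: within a single withColumn, the window only ever sees the END column as it was before that withColumn, never the values being written. On a fresh copy of the original DataFrame:

df0 = sqlcontext.createDataFrame([(1, 5), (3, 6), (7, 10), (13, 17), (15, 20)], ['START', 'END'])
df0.withColumn('PREV_END_SEEN', F.last('END', True).over(w)).show()
# PREV_END_SEEN comes out as null, 5, 6, 10, 17 -- always the original
# previous END, never the shifted END the first withColumn is producing.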

You need to use rowsBetween in place of rangeBetween. This will select the rows (by position) within the window spec, rather than a range of ordering values. See https://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=exception#pyspark.sql.Window.rowsBetween

w = Window().orderBy('START').rowsBetween(Window.unboundedPreceding, -1)
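
To illustrate the difference (a small sketch with made-up numbers where the ORDER BY column has ties, reusing the sqlcontext from the question): rowsBetween builds the frame from physical row positions, while rangeBetween builds it from the values of the ordering column.

toy = sqlcontext.createDataFrame([(1, 5), (3, 6), (3, 7), (8, 10)], ['START', 'END'])

w_rows = Window().orderBy('START').rowsBetween(Window.unboundedPreceding, -1)
w_range = Window().orderBy('START').rangeBetween(Window.unboundedPreceding, -1)

# Count how many rows fall into each type of frame.
toy.withColumn('ROWS_FRAME_SIZE', F.count(F.lit(1)).over(w_rows)) \
   .withColumn('RANGE_FRAME_SIZE', F.count(F.lit(1)).over(w_range)) \
   .show()

# ROWS_FRAME_SIZE counts the rows physically before the current row: 0, 1, 2, 3.
# RANGE_FRAME_SIZE counts the rows whose START is <= current START - 1: 0, 1, 1, 3
# (the two START = 3 rows get the same frame).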

I couldn't find a function that can retrieve the previous row's value from an updating column. I guess you cannot do that in PySpark, and the only workaround is to use a UDF and loop inside it, as sketched below.
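
That workaround could look roughly like this (just a sketch: it assumes Spark 3.x for applyInPandas, and it pulls all rows into a single pandas group, so it only makes sense for small data):

import pandas as pd

def stack_lines_pdf(pdf: pd.DataFrame) -> pd.DataFrame:
    # Same loop as the plain-Python sketch in the question, applied to the
    # whole DataFrame collected into one pandas group.
    pdf = pdf.sort_values('START').reset_index(drop=True)
    prev_end = None
    for i in range(len(pdf)):
        start, end = pdf.at[i, 'START'], pdf.at[i, 'END']
        length = end - start
        if prev_end is not None and start < prev_end:
            start = prev_end
            end = start + length
        pdf.at[i, 'START'], pdf.at[i, 'END'] = start, end
        prev_end = end
    return pdf[['START', 'END']]

# df is assumed to be the original (START, END) DataFrame.
stacked = df.groupBy(F.lit(1)) \
            .applyInPandas(stack_lines_pdf, schema='START long, END long') \
            .orderBy('START')
stacked.show()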

But I was able to solve the problem from my question above without using any UDF, using some logic.

Below is my code. My logic was worked out visually, so I am not able to explain it well here. If anyone is curious and cannot follow the logic from the code, comment below and I will try to explain it.

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.functions import lag, col
from pyspark.sql.window import Window
from pyspark.sql import functions as F

sc = SparkContext('local', 'Example_2')
sqlcontext = SQLContext(sc)

df = sqlcontext.createDataFrame([(1, 5), (3, 6), (7, 10), (13, 17), (15, 20)], ['START', 'END'])

w = Window().orderBy('START')

# Length of each line; it must be preserved when the line is shifted.
df = df.withColumn('LENGTH', df.END - df.START)

# Total length of all previous lines, i.e. how far the stack has grown
# if the lines are packed end to end starting from the first START.
df = df.withColumn('LENGTH_CUMU',
                   F.sum(df.LENGTH).over(w.rowsBetween(Window.unboundedPreceding, -1)))

# Distance of the current line's original START from the very first START.
df = df.withColumn('FIRST_START_DIFF',
                   df.START - F.first('START').over(w))

# If a line originally starts beyond the packed stack, there is a gap that
# must be kept; REQ_SHIFT is the size of that gap.
df = df.withColumn('REQ_SHIFT',
                   F.when(df.FIRST_START_DIFF > df.LENGTH_CUMU,
                          df.FIRST_START_DIFF - df.LENGTH_CUMU) \
                    .otherwise(0))

# Gaps accumulate, so carry forward the largest shift seen so far.
df = df.withColumn('REQ_SHIFT',
                   F.max('REQ_SHIFT').over(w.rowsBetween(Window.unboundedPreceding, 0)))

# New START = first START + cumulative length of previous lines + accumulated gap.
# (START - FIRST_START_DIFF is just the first START; coalesce keeps the first
# row's original START, where LENGTH_CUMU is null.)
df = df.withColumn('START',
                   F.coalesce(df.START - df.FIRST_START_DIFF + df.LENGTH_CUMU + df.REQ_SHIFT, df.START))

# New END = new START + original length.
df = df.withColumn('END', df.START + df.LENGTH)

df = df.select('START', 'END')

df.show()

Now it gives the right output, which is:

+-----+---+
|START|END|
+-----+---+
|    1|  5|
|    5|  8|
|    8| 11|
|   13| 17|
|   17| 22|
+-----+---+
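
If the logic is hard to follow from the code alone, here is a hand trace of the intermediate columns for the sample data (REQ_SHIFT shown after the running max). The new START is START - FIRST_START_DIFF + LENGTH_CUMU + REQ_SHIFT, i.e. the first START plus the length of everything stacked before plus the accumulated gap, with the first row falling back to its original START via coalesce:

+-----+---+------+-----------+----------------+---------+
|START|END|LENGTH|LENGTH_CUMU|FIRST_START_DIFF|REQ_SHIFT|
+-----+---+------+-----------+----------------+---------+
|    1|  5|     4|       null|               0|        0|
|    3|  6|     3|          4|               2|        0|
|    7| 10|     3|          7|               6|        0|
|   13| 17|     4|         10|              12|        2|
|   15| 20|     5|         14|              14|        2|
+-----+---+------+-----------+----------------+---------+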
