
Cumulative sum of a partition in PySpark

I need to create a column with a group number that increments based on the values in the column TRUE. I can partition by ID, so I expect the increment to reset whenever the ID changes, which is what I want. Within an ID, I want to increment the group number whenever TRUE is not equal to 1; when TRUE = 1, the number should stay the same as in the previous row. Below is a subset of my current ID and TRUE columns, with GROUP shown as desired. I also have LATITUDE and LONGITUDE columns that I use in my sort.

ID    TRUE  GROUP
3828    0   1
3828    0   2
3828    1   2
3828    1   2
3828    1   2
4529    0   1
4529    1   1
4529    0   2
4529    1   2
4529    0   3
4529    0   4
4529    1   4
4529    0   5
4529    1   5
4529    1   5
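
The rule can be sanity-checked outside Spark. The following is a minimal plain-Python sketch, not PySpark code: it assumes the rows are already sorted within each ID, and the helper name `assign_groups` is made up for illustration. It keeps a running count of rows where TRUE is not 1 and restarts the count whenever the ID changes.

```python
from itertools import groupby

def assign_groups(rows):
    """rows: iterable of (id, true_flag) tuples, already in the desired order.

    Returns a list of (id, true_flag, group) tuples, where group is the
    running count of rows with true_flag != 1 within each id.
    """
    result = []
    for row_id, chunk in groupby(rows, key=lambda r: r[0]):
        running = 0  # reset for every new ID
        for _, flag in chunk:
            if flag != 1:
                running += 1
            result.append((row_id, flag, running))
    return result

# The ID and TRUE columns from the sample above
rows = [(3828, 0), (3828, 0), (3828, 1), (3828, 1), (3828, 1),
        (4529, 0), (4529, 1), (4529, 0), (4529, 1), (4529, 0),
        (4529, 0), (4529, 1), (4529, 0), (4529, 1), (4529, 1)]

for row in assign_groups(rows):
    print(*row)
```

On the sample rows this reproduces the GROUP column shown above, which is the same thing a per-ID cumulative sum of "TRUE = 0" computes.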

I was hoping to do something like the query below, but it gives me all 0s:

trip.registerTempTable("trip_temp")
trip2 = sqlContext.sql('select *, sum(cast(TRUE = 0 as int)) over(partition by ID order by ID, LATITUDE, LONGITUDE) as GROUP from trip_temp')

I know the question is quite old, but I wanted to share this for those who might be looking for an optimal way.

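from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Frame from the first row of the partition up to the current row (a running total)
cumSumPartition = Window.partitionBy(['col1','col2','col3','col4']).orderBy("col3").rowsBetween(Window.unboundedPreceding, Window.currentRow)

# F.sum is Spark's aggregate; Python's built-in sum does not work on a Column
temp = temp.withColumn("col5", F.sum(temp.col5).over(cumSumPartition))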
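
The explicit rowsBetween frame matters. When a window has an ordering but no frame, Spark defaults to RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, which includes every row that ties with the current row on the ordering key. Below is a plain-Python sketch of the two behaviours on a single partition; the helper names are hypothetical and are not Spark APIs.

```python
def running_sum_rows(pairs):
    """ROWS frame: sum of values from the first row up to the current row."""
    out, total = [], 0
    for _, value in pairs:
        total += value
        out.append(total)
    return out

def running_sum_range(pairs):
    """RANGE frame: sum of values for all rows whose order key <= current key."""
    return [sum(v for k, v in pairs if k <= key) for key, _ in pairs]

# (order_key, value) pairs of one partition, sorted by order_key.
# The first two rows tie on the order key.
partition = [(1, 10), (1, 20), (2, 5)]

print(running_sum_rows(partition))   # [10, 30, 35]
print(running_sum_range(partition))  # [30, 30, 35]
```

With tied ordering keys, the range frame gives both tied rows the same total, while the rows frame still accumulates one row at a time.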

Never use reserved keywords as column names. Even if it works in some systems, it is error-prone, may stop working if you change the parser, and is generally bad practice. TRUE is a boolean literal and will never be equal to 0 (with the implicit cast, the comparison is equivalent to TRUE IS NOT TRUE):

spark.createDataFrame(
    [(3828, 0, 1), (3828, 1, 2)], ("ID", "TRUE", "GROUP")
).createOrReplaceTempView("trip_temp")

spark.sql("SELECT TRUE = 0 AS foo FROM trip_temp LIMIT 2").show()

# +-----+
# |  foo|
# +-----+
# |false|
# |false|
# +-----+

If you really want to make it work, use backticks:

spark.sql("SELECT `TRUE` = 0 AS foo FROM trip_temp LIMIT 2").show()

# +-----+
# |  foo|
# +-----+
# | true|
# |false|
# +-----+

but please don't.
