简体   繁体   English

如何在PySpark中为一个组遍历Dataframe / RDD的每一行?

[英]How to iterate over each row of an Dataframe / RDD in PySpark for a group.?

I want to set the value of column based on the value of that column in the previous row for a group. 我想基于组的上一行中该列的值来设置列的值。 Then this updated value will be used in the next row. 然后,此更新的值将在下一行中使用。

I have the following dataframe 我有以下数据框

id | start_date|sort_date | A | B |
-----------------------------------
1 | 1/1/2017 | 31-01-2015 | 1 | 0 | 
1 | 1/1/2017 | 28-02-2015 | 0 | 0 | 
1 | 1/1/2017 | 31-03-2015 | 1 | 0 | 
1 | 1/1/2017 | 30-04-2015 | 1 | 0 | 
1 | 1/1/2017 | 31-05-2015 | 1 | 0 | 
1 | 1/1/2017 | 30-06-2015 | 1 | 0 | 
1 | 1/1/2017 | 31-07-2015 | 1 | 0 | 
1 | 1/1/2017 | 31-08-2015 | 1 | 0 | 
1 | 1/1/2017 | 30-09-2015 | 0 | 0 | 
2 | 1/1/2017 | 31-10-2015 | 1 | 0 | 
2 | 1/1/2017 | 30-11-2015 | 0 | 0 | 
2 | 1/1/2017 | 31-12-2015 | 1 | 0 | 
2 | 1/1/2017 | 31-01-2016 | 1 | 0 | 
2 | 1/1/2017 | 28-02-2016 | 1 | 0 | 
2 | 1/1/2017 | 31-03-2016 | 1 | 0 | 
2 | 1/1/2017 | 30-04-2016 | 1 | 0 | 
2 | 1/1/2017 | 31-05-2016 | 1 | 0 | 
2 | 1/1/2017 | 30-06-2016 | 0 | 0 | 

Output : 输出:

id | start_date|sort_date | A | B | C
---------------------------------------
1 | 1/1/2017 | 31-01-2015 | 1 | 0 | 1
1 | 1/1/2017 | 28-02-2015 | 0 | 0 | 0
1 | 1/1/2017 | 31-03-2015 | 1 | 0 | 1
1 | 1/1/2017 | 30-04-2015 | 1 | 0 | 2
1 | 1/1/2017 | 31-05-2015 | 1 | 0 | 3
1 | 1/1/2017 | 30-06-2015 | 1 | 0 | 4
1 | 1/1/2017 | 31-07-2015 | 1 | 0 | 5
1 | 1/1/2017 | 31-08-2015 | 1 | 0 | 6
1 | 1/1/2017 | 30-09-2015 | 0 | 0 | 0
2 | 1/1/2017 | 31-10-2015 | 1 | 0 | 1
2 | 1/1/2017 | 30-11-2015 | 0 | 0 | 0
2 | 1/1/2017 | 31-12-2015 | 1 | 0 | 1
2 | 1/1/2017 | 31-01-2016 | 1 | 0 | 2
2 | 1/1/2017 | 28-02-2016 | 1 | 0 | 3
2 | 1/1/2017 | 31-03-2016 | 1 | 0 | 4
2 | 1/1/2017 | 30-04-2016 | 1 | 0 | 5
2 | 1/1/2017 | 31-05-2016 | 1 | 0 | 6
2 | 1/1/2017 | 30-06-2016 | 0 | 0 | 0

Group is of id and date 群组的编号和日期

Column C is to derived based on column A and B. 列C将基于列A和B派生。

If A == 1 and B == 0 then C is derived C from previous row + 1. 如果A == 1且B == 0,则C从上一行+ 1导出C。
There are some other conditions as well but I am struggling with this part. 还有其他一些条件,但是我正在为这一部分苦苦挣扎。

Assuming we have a column sort_date in dataframe. 假设我们在数据框中有一个sort_date列。

I tried the following query : 我尝试了以下查询:

SELECT
id,
date,
sort_date,
lag(A) OVER (PARTITION BY  id, date ORDER BY sort_date) as prev,
CASE
   WHEN A=1 AND B= 0  THEN 1
   WHEN  A=1 AND B> 0 THEN prev +1
   ELSE 0
 END AS A
FROM
Table

This Is what I did for UDAF 这就是我为UDAF做的

val myFunc = new MyUDAF
val w = Window.partitionBy(col("ID"), col("START_DATE")).orderBy(col("SORT_DATE"))
val df = df.withColumn("C", myFunc(col("START_DATE"), col("X"),
  col("Y"), col("A"),
  col("B")).over(w))

PS : I am using Spark 1.6 PS:我正在使用Spark 1.6

First define a window: 首先定义一个窗口:

import org.apache.spark.sql.expressions.Window
val winspec = Window.partitionBy("id","start_date").orderBy("sort_date")

Next create a UDAF which recieves A and B and basically calculates C by starting with 0, changing to 0 whenever the condition appears (A=1,B=0) and increasing by 1 any other time. 接下来,创建一个UDAF,它接收A和B并基本上从0开始计算C,只要条件出现(A = 1,B = 0)便更改为0,然后再增加1。 To see how to write a UDAF see examples in here , here and here 要了解如何编写UDAF,请参见此处此处此处的示例

EDIT Here is a sample implementation of the UDAF (not really tested so there may be typos): 编辑这是UDAF的示例实现(未经实际测试,因此可能有错别字):

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer,UserDefinedAggregateFunction}
 import org.apache.spark.sql.types._

 class myFunc() extends UserDefinedAggregateFunction {

  // Input Data Type Schema
  def inputSchema: StructType = StructType(Array(StructField("A", IntegerType), StructField("A", IntegerType)))

   // Intermediate Schema
  def bufferSchema = StructType(Array(StructField("C", IntegerType)))

  // Returned Data Type .
  def dataType: DataType = IntegerType

  // Self-explaining
  def deterministic = true

  // This function is called whenever key changes
  def initialize(buffer: MutableAggregationBuffer) = {
    buffer(0) = 0 // set number of items to 0
  }

  // Iterate over each entry of a group
  def update(buffer: MutableAggregationBuffer, input: Row) = {
    buffer(0) = if (input.getInt(0) == 1 && input.getInt(1) == 0) buffer.getInt(0) + 1 else 0
  }

  // Merge two partial aggregates - doesn't really matter because the window will make sure the buffer remains in a
  // single partition
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row) = {
    buffer1(0) = buffer1.getInt(0) + buffer2.getInt(0)
  }

  // Called after all the entries are exhausted.
  def evaluate(buffer: Row) = {
    buffer.getInt(0)
  }

}

Last apply it to your dataframe. 最后将其应用于您的数据框。 Let's assume you named your UDAF myFunc: 假设您将UDAF myFunc命名为:

val f = new myFunc()
val newDF = df.withColumn("newC", f($"A",$"B").over(winspec))

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM