How can I use PySpark's Window function to model exponential decay?
I am trying to apply a PySpark Window function to do "exponential decay". The formula is

todays_score = yesterdays_score * (weight) + todays_raw_score

So for example, suppose we have a dataframe that is ordered by day and has a raw score of 1 every day:
+---+----+---------+
|day|user|raw_score|
+---+----+---------+
|  0|   a|        1|
|  1|   a|        1|
|  2|   a|        1|
|  3|   a|        1|
+---+----+---------+
If I were to calculate todays_score, it would look like this:
+---+----+---------+------------+
|day|user|raw_score|todays_score|   # Here's the math:
+---+----+---------+------------+
|  0|   a|        1|         1.0|   (0    * .90) + 1
|  1|   a|        1|         1.9|   (1.0  * .90) + 1
|  2|   a|        1|        2.71|   (1.9  * .90) + 1
|  3|   a|        1|       3.439|   (2.71 * .90) + 1
+---+----+---------+------------+
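The recurrence in the table can be sketched in plain Python (a minimal illustration, using a weight of 0.9 as in the math column above):

```python
# Plain-Python sketch of the decay recurrence shown in the table.
weight = 0.9
raw_scores = [1, 1, 1, 1]

todays_score = 0.0
scores = []
for raw in raw_scores:
    # todays_score = yesterdays_score * weight + todays_raw_score
    todays_score = todays_score * weight + raw
    scores.append(round(todays_score, 4))

print(scores)  # [1.0, 1.9, 2.71, 3.439]
```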
I've tried using window functions; however, based on what I've seen, they can only use the "static values" from the original dataframe, not the values we just calculated. I've even tried creating a "dummy column" to start the process; however, that didn't work either.

My attempted code:
from pyspark.sql import functions as F
from pyspark.sql.functions import col
from pyspark.sql.window import Window

df = sqlContext.createDataFrame([
    (0, 'a', 1),
    (1, 'a', 1),
    (2, 'a', 1),
    (3, 'a', 1)],
    ['day', 'user', 'raw_score']
)
df.show()

# Create a "dummy column" (weighted score) so we can use it.
df2 = df.select('*', col('raw_score').alias('todays_score'))
df2.show()

w = Window.partitionBy('user')
df2.withColumn(
    'todays_score',
    F.lag('todays_score', 1, 0).over(w.orderBy('day')) * 0.9 + F.col('raw_score')
).show()
The (undesired) output of this is:
+---+----+---------+------------+
|day|user|raw_score|todays_score|
+---+----+---------+------------+
|  0|   a|        1|         1.0|
|  1|   a|        1|         1.9|
|  2|   a|        1|         1.9|
|  3|   a|        1|         1.9|
+---+----+---------+------------+
which only takes the previous raw value * (.90), rather than the value that was just calculated.

How can I access the values that were just calculated by the window function?
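To see why every row after day 1 shows 1.9: `lag` reads the dummy column as it existed before `withColumn` ran, so each row sees its predecessor's original `raw_score`, never the freshly computed score. A plain-Python mimic of the attempted expression (an illustration, not Spark code):

```python
# Mimic of the attempted window expression: lag() reads the ORIGINAL
# dummy column, which never contains the values being computed.
raw_score = [1, 1, 1, 1]
dummy = list(raw_score)   # the "todays_score" dummy column

out = []
for i in range(len(dummy)):
    prev = dummy[i - 1] if i > 0 else 0   # F.lag(..., 1, 0)
    out.append(prev * 0.9 + raw_score[i])

print(out)  # [1.0, 1.9, 1.9, 1.9] -- matches the undesired output
```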
For Spark 2.4+, you can use the higher-order functions transform, aggregate, filter, and arrays_zip like this. It will work for any combination of raw_score values and will be faster than a pandas UDF (assuming the data has been ordered by day per user, as shown in the sample).
df.show()  # sample dataframe
#+---+----+---------+
#|day|user|raw_score|
#+---+----+---------+
#|  0|   a|        1|
#|  1|   a|        1|
#|  2|   a|        1|
#|  3|   a|        1|
#+---+----+---------+

from pyspark.sql import functions as F

df\
    .groupBy("user")\
    .agg(F.collect_list("raw_score").alias("raw_score"),
         F.collect_list("day").alias("day"))\
    .withColumn("raw_score1",
                F.expr("transform(raw_score, (x, i) -> struct(x as raw, i as index))"))\
    .withColumn("todays_score",
                F.expr("""transform(raw_score1, x ->
                            aggregate(filter(raw_score1, z -> z.index <= x.index),
                                      cast(0 as double),
                                      (acc, y) -> (acc * 0.9) + y.raw))"""))\
    .withColumn("zip", F.explode(F.arrays_zip("day", "raw_score", "todays_score")))\
    .select("user", "zip.*")\
    .show(truncate=False)
#+----+---+---------+------------+
#|user|day|raw_score|todays_score|
#+----+---+---------+------------+
#|a   |0  |1        |1.0         |
#|a   |1  |1        |1.9         |
#|a   |2  |1        |2.71        |
#|a   |3  |1        |3.439       |
#+----+---+---------+------------+
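The `transform`/`aggregate` expression above is a nested fold: for each index, it filters the array down to the prefix ending at that index and folds `acc * 0.9 + raw` over it, starting from 0.0. A plain-Python equivalent, for intuition only (the function name is illustrative):

```python
# Plain-Python equivalent of the SQL transform/aggregate expression:
# for each index i, fold acc -> acc * weight + raw over raw_score[0..i].
def decayed_scores(raw_score, weight=0.9):
    out = []
    for i in range(len(raw_score)):
        acc = 0.0
        for raw in raw_score[:i + 1]:
            acc = acc * weight + raw
        out.append(acc)
    return out

print(decayed_scores([1, 1, 1, 1]))  # ~[1.0, 1.9, 2.71, 3.439]
```

Note that, like the `filter`-per-element in the SQL version, this is O(n²) per group, which is fine for short per-user histories.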
UPDATE:

Assuming the data has been ordered by day as shown in the sample, you can use a Pandas Grouped Map UDF like this:
import pandas as pd
from pyspark.sql import functions as F
from pyspark.sql.functions import pandas_udf, PandasUDFType

# Reuse the input schema, but with raw_score typed as double
# (the lit(1.2456) trick only changes the column's type in the schema).
@pandas_udf(df.withColumn("raw_score", F.lit(1.2456)).schema, PandasUDFType.GROUPED_MAP)
def grouped_map(df):
    for i in range(1, len(df)):
        df.loc[i, 'raw_score'] = (df.loc[i - 1, 'raw_score'] * 0.9) + 1
    return df

df.groupby("user").apply(grouped_map).show()
#+---+----+---------+
#|day|user|raw_score|
#+---+----+---------+
#|  0|   a|      1.0|
#|  1|   a|      1.9|
#|  2|   a|     2.71|
#|  3|   a|    3.439|
#+---+----+---------+
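As an aside, the recurrence also has a closed form that an ordinary running window sum can compute: score_i = w^i * Σ_{j≤i} raw_j / w^j. A quick check of that identity in plain Python (in PySpark this would correspond to something like `F.pow(F.lit(0.9), 'day') * F.sum(F.col('raw_score') / F.pow(F.lit(0.9), 'day')).over(Window.partitionBy('user').orderBy('day'))` — an untested sketch; note `w**-day` can overflow for very long histories):

```python
# Verify: the recurrence  score[i] = score[i-1]*w + raw[i]
# equals the closed form  score[i] = w**i * sum(raw[j] / w**j for j <= i).
w = 0.9
raw = [1, 3, 2, 1]   # arbitrary illustrative raw scores

rec, prev = [], 0.0
for r in raw:
    prev = prev * w + r
    rec.append(prev)

closed = [w**i * sum(raw[j] / w**j for j in range(i + 1))
          for i in range(len(raw))]

assert all(abs(a - b) < 1e-9 for a, b in zip(rec, closed))
```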