[英]pyspark: Using the window() function and compare strings
我有以下测试数据。 有了这个数据,如果必须在 pyspark 中编写如下规则(实际数据确实很大):
import pandas as pd
import datetime
data = {'date': ['2014-01-01', '2014-01-02', '2014-01-03', '2014-01-04', '2014-01-05', '2014-01-06'],
'customerid': [1, 2, 2, 3, 4, 3], 'productids': ['A;B', 'D;E', 'H;X', 'P;Q;G', 'S;T;U', 'C;G']}
data = pd.DataFrame(data)
data['date'] = pd.to_datetime(data['date'])
规则如下:
“对于一个客户 ID,在 y 天内,购物篮中有超过 x 笔相同产品的交易。”
在我的示例中,我将 go 返回 x=2 天并检查至少 y=1 匹配的客户 ID。 结果应该是这样的:
date |customerid|result 2014-01-01|1 |0 2014-01-02|2 |0 2014-01-03|2 |0 2014-01-04|3 |0 2014-01-05|4 |0 2014-01-06|3 |1
有一段时间 window 2 天,仅 2014-01-06 我们有相同的客户 ID 出现的情况(2014 年 1 月 4 日的客户 ID 3)并且还有一个匹配的产品(G)。
我知道我可以像这样使用时间 window :
win = Window().partitionBy('customerid').orderBy((F.col('date')).cast("long")).rangeBetween(
-(2*86400), Window.currentRow)
不幸的是,我现在没有进一步的进展。 我也绝对不知道如何比较 productid,因为它们始终只能作为长字符串使用。
谢谢!
这将适用于spark2.4+ (因为array_distinct
)。 只要您的productids
由;
分隔 ,我们可以在该分隔符上split
以创建一个列表。 使用您已经拥有的window
,我们collect_list
,将其flatten
,然后查看我们有多少重复项。 重复的数量是您想要的result
。
from pyspark.sql import functions as F
from pyspark.sql.window import Window
df=spark.createDataFrame(data)
w=Window().partitionBy("customerid").orderBy(F.col("date").cast("long")).rangeBetween(-86400*2,0)
df.withColumn("productids", F.split("productids", "\;"))\
.withColumn("products", F.flatten(F.collect_list("productids").over(w)))\
.withColumn("result", F.size("products") - F.size(F.array_distinct("products")))\
.orderBy(F.col("date")).drop("productids","products").show()
+-------------------+----------+------+
| date|customerid|result|
+-------------------+----------+------+
|2014-01-01 00:00:00| 1| 0|
|2014-01-02 00:00:00| 2| 0|
|2014-01-03 00:00:00| 2| 0|
|2014-01-04 00:00:00| 3| 0|
|2014-01-05 00:00:00| 4| 0|
|2014-01-06 00:00:00| 3| 1|
+-------------------+----------+------+
更新:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
w=Window().partitionBy("customerid").orderBy(F.col("date").cast("long")).rangeBetween(-86400*2,0)
df.withColumn("productids", F.array_distinct(F.split("productids", "\;")))\
.withColumn("products", F.flatten((F.collect_list("productids").over(w))))\
.withColumn("result", F.when(F.size("products")!=F.size(F.array_distinct("products")),F.lit(1)).otherwise(F.lit(0)))\
.drop("productids","products").orderBy("date").show()
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.