繁体   English   中英

pyspark:使用 window() function 并比较字符串

[英]pyspark: Using the window() function and compare strings

我有以下测试数据。 有了这个数据,如果必须在 pyspark 中编写如下规则(实际数据确实很大):

import pandas as pd
import datetime

data = {'date': ['2014-01-01', '2014-01-02', '2014-01-03', '2014-01-04', '2014-01-05', '2014-01-06'],
     'customerid': [1, 2, 2, 3, 4, 3], 'productids': ['A;B', 'D;E', 'H;X', 'P;Q;G', 'S;T;U', 'C;G']}
data = pd.DataFrame(data)
data['date'] = pd.to_datetime(data['date'])

规则如下:

“对于一个客户 ID,在 y 天内,购物篮中有超过 x 笔相同产品的交易。”

在我的示例中,我将 go 返回 x=2 天并检查至少 y=1 匹配的客户 ID。 结果应该是这样的:

date      |customerid|result
2014-01-01|1         |0     
2014-01-02|2         |0         
2014-01-03|2         |0         
2014-01-04|3         |0         
2014-01-05|4         |0         
2014-01-06|3         |1

有一段时间 window 2 天,仅 2014-01-06 我们有相同的客户 ID 出现的情况(2014 年 1 月 4 日的客户 ID 3)并且还有一个匹配的产品(G)。

我知道我可以像这样使用时间 window :

win = Window().partitionBy('customerid').orderBy((F.col('date')).cast("long")).rangeBetween(
        -(2*86400), Window.currentRow)

不幸的是,我现在没有进一步的进展。 我也绝对不知道如何比较 productid,因为它们始终只能作为长字符串使用。

谢谢!

这将适用于spark2.4+ (因为array_distinct )。 只要您的productids;分隔 ,我们可以在该分隔符上split以创建一个列表。 使用您已经拥有的window ,我们collect_list ,将其flatten ,然后查看我们有多少重复项 重复的数量是您想要的result

from pyspark.sql import functions as F
from pyspark.sql.window import Window    
df=spark.createDataFrame(data)
w=Window().partitionBy("customerid").orderBy(F.col("date").cast("long")).rangeBetween(-86400*2,0)
df.withColumn("productids", F.split("productids", "\;"))\
  .withColumn("products", F.flatten(F.collect_list("productids").over(w)))\
  .withColumn("result", F.size("products") - F.size(F.array_distinct("products")))\
  .orderBy(F.col("date")).drop("productids","products").show()


+-------------------+----------+------+
|               date|customerid|result|
+-------------------+----------+------+
|2014-01-01 00:00:00|         1|     0|
|2014-01-02 00:00:00|         2|     0|
|2014-01-03 00:00:00|         2|     0|
|2014-01-04 00:00:00|         3|     0|
|2014-01-05 00:00:00|         4|     0|
|2014-01-06 00:00:00|         3|     1|
+-------------------+----------+------+

更新:

from pyspark.sql import functions as F
from pyspark.sql.window import Window
w=Window().partitionBy("customerid").orderBy(F.col("date").cast("long")).rangeBetween(-86400*2,0)
df.withColumn("productids", F.array_distinct(F.split("productids", "\;")))\
  .withColumn("products", F.flatten((F.collect_list("productids").over(w))))\
  .withColumn("result", F.when(F.size("products")!=F.size(F.array_distinct("products")),F.lit(1)).otherwise(F.lit(0)))\
  .drop("productids","products").orderBy("date").show()

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM