
PySpark - compare single list of integers to column of lists

I'm trying to check which entries in a Spark dataframe (a column with lists) contain the largest number of values from a given list.

The best approach I've come up with is iterating over the dataframe with rdd.foreach() and comparing a given list to every entry using Python's set1.intersection(set2).

My question is: does Spark have any built-in functionality for this, so that iterating with .foreach could be avoided?

Thanks for any help!

P.S. My dataframe looks like this:

+-------------+---------------------+                                           
|   cardnumber|collect_list(article)|
+-------------+---------------------+
|2310000000855| [12480, 49627, 80...|
|2310000008455| [35531, 22564, 15...|
|2310000011462| [117112, 156087, ...|
+-------------+---------------------+

And I'm trying to find the entries whose lists in the second column have the most values in common with a given list of articles, e.g. [151574, 87239, 117908, 162475, 48599].

You can try the same set operation in the dataframe instead of using rdd.foreach:

from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, LongType

# UDF returning the common elements of the two array columns A and B
my_udf = udf(lambda A, B: list(set(A).intersection(set(B))), ArrayType(LongType()))
df = df.withColumn('intersect_value', my_udf(col('A'), col('B')))

You can use the len function inside the UDF itself to get the size of the intersect list, and then perform the operation you want on this dataframe.
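For example, a minimal sketch along those lines (assuming the list column is named articles and using the given list from the question) ranks the rows by the number of matches:

from pyspark.sql.functions import udf, col, desc
from pyspark.sql.types import IntegerType

# the given list of articles from the question
given_articles = [151574, 87239, 117908, 162475, 48599]

# count how many of the given articles appear in each row's list
match_count = udf(lambda xs: len(set(given_articles).intersection(xs)), IntegerType())

(df.withColumn('match_count', match_count(col('articles')))
   .orderBy(desc('match_count'))
   .show())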

The only alternative here is a udf, but it won't make much of a difference.

from pyspark.sql.functions import udf

def intersect(xs):
    xs = set(xs)
    @udf("array<long>")
    def _(ys):
        return list(xs.intersection(ys))
    return _

It can be applied as:

a_list = [1, 4, 6]

df = spark.createDataFrame([
    (1, [3, 4, 8]), (2, [7, 2, 6])
], ("id", "articles"))

df.withColumn("intersect", intersect(a_list)("articles")).show()

# +---+---------+---------+
# | id| articles|intersect|
# +---+---------+---------+
# |  1|[3, 4, 8]|      [4]|
# |  2|[7, 2, 6]|      [6]|
# +---+---------+---------+
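To get back to the original question (which rows share the most articles with the given list), one possible follow-up sketch applies the built-in size function to the intersect column and sorts by it:

from pyspark.sql.functions import size, desc

# number of shared articles per row, largest first
(df.withColumn("intersect", intersect(a_list)("articles"))
   .withColumn("n_matches", size("intersect"))
   .orderBy(desc("n_matches"))
   .show())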

Based on the names, it looks like you used collect_list, so your data probably looks like this:

df_long = spark.createDataFrame([
    (1, 3),(1, 4), (1, 8), (2, 7), (2, 7), (2, 6)
], ("id", "articles"))

In that case the problem is simpler. Join:

lookup = spark.createDataFrame(a_list, "long").toDF("articles")

joined = lookup.join(df_long, ["articles"])

and aggregate the result:

joined.groupBy("id").count().show()
# +---+-----+                                                                     
# | id|count|
# +---+-----+
# |  1|    1|
# |  2|    1|
# +---+-----+


joined.groupBy("id").agg(collect_list("articles")).show()
# +---+----------------------+                                                    
# | id|collect_list(articles)|
# +---+----------------------+
# |  1|                   [4]|
# |  2|                   [6]|
# +---+----------------------+
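And if only the entries with the most matches are needed, a small addition (not in the original answer) is to sort the per-id counts:

from pyspark.sql.functions import desc

# ids with the most matching articles first
joined.groupBy("id").count().orderBy(desc("count")).show()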
