[英]Perform a user defined function on a column of a large pyspark dataframe based on some columns of another pyspark dataframe on databricks
我的问题与我之前在如何有效地加入大型 pyspark 数据帧和小型 python 列表以获得数据块上的一些 NLP 结果相关。
我已经解决了其中的一部分,现在又遇到了另一个问题。
我有一个小的 pyspark 数据框,如:
df1:
+-----+--------------------------------------------------+--------------------------------------------------+--------------------------------------------------+
|topic| termIndices| termWeights| terms|
+-----+--------------------------------------------------+--------------------------------------------------+--------------------------------------------------+
| 0| [3, 155, 108, 67, 239, 4, 72, 326, 128, 189]|[0.023463344607734377, 0.011772322769900843, 0....|[cell, apoptosis, uptake, loss, transcription, ...|
| 1| [16, 8, 161, 86, 368, 153, 18, 214, 21, 222]|[0.013057307487199429, 0.011453455929929763, 0....|[therapy, cancer, diet, lung, marker, sensitivi...|
| 2| [0, 1, 124, 29, 7, 2, 84, 299, 22, 90]|[0.03979063871841061, 0.026593954837078836, 0.0...|[group, expression, performance, use, disease, ...|
| 3| [204, 146, 74, 240, 152, 384, 55, 250, 238, 92]|[0.009305626056223443, 0.008840730657888991, 0....|[pattern, chemotherapy, mass, the amount, targe...|
它只有不到 100 行,而且非常小。 每个术语在“termWeights”列中都有一个 termWeight 值。
我有另一个大型 pyspark 数据框(50+ GB),例如:
df2:
+------+--------------------------------------------------+
|r_id| tokens|
+------+--------------------------------------------------+
| 0|[The human KCNJ9, Kir, GIRK3, member, potassium...|
| 1|[BACKGROUND, the treatment, breast, cancer, the...|
| 2|[OBJECTIVE, the relationship, preoperative atri...|
对于 df2 中的每一行,我需要在 df1 中找到所有主题中具有最高 termWeights 的最佳匹配项。
最后,我需要一个 df 像
r_id tokens topic (the topic in df1 that has the highest sum of termWeights among all topics)
我已经定义了一个 UDF(基于 df2),但它无法访问 df1 的列。 我正在考虑如何对 df1 和 df2 使用“交叉连接”,但我不需要将 df2 的每一行与 df1 的每一行连接起来。 我只需要保留 df2 的所有列,并根据每个 df1 主题与每个 df2 行的术语的匹配项,添加具有最高 termWeights 总和的“主题”列。
我不确定如何通过 pyspark.sql.functions.udf 实现这个逻辑。
IIUC,您可以尝试如下操作(我将处理流程分为4个步骤,需要Spark 2.4+ ):
步骤 1:将所有 df2.tokens 转换为小写,以便我们可以进行文本比较:
from pyspark.sql.functions import expr, desc, row_number, broadcast
df2 = df2.withColumn('tokens', expr("transform(tokens, x -> lower(x))"))
步骤 2:使用arrays_overlap将 df2 与 df1 左连接
df3 = df2.join(broadcast(df1), expr("arrays_overlap(terms, tokens)"), "left")
Step-3:使用聚合函数从terms 、 termWeights和tokens计算matched_sum_of_weights
df4 = df3.selectExpr(
"r_id",
"tokens",
"topic",
"""
aggregate(
/* find all terms+termWeights which are shown in tokens array */
filter(arrays_zip(terms,termWeights), x -> array_contains(tokens, x.terms)),
0D,
/* get the sum of all termWeights from the matched terms */
(acc, y) -> acc + y.termWeights
) as matched_sum_of_weights
""")
步骤 4:对于每个 r_id,使用 Window 函数找到具有最高matched_sum_of_weights
的行,并且只保留row_number == 1
行
from pyspark.sql import Window
w1 = Window.partitionBy('r_id').orderBy(desc('matched_sum_of_weights'))
df_new = df4.withColumn('rn', row_number().over(w1)).filter('rn=1').drop('rn', 'matched_sum_of_weights')
替代方案:如果 df1 的大小不是很大,这可能会在没有 join/window.partition 等的情况下处理。下面的代码仅概述了您应该根据实际数据改进的想法:
from pyspark.sql.functions import expr, when, coalesce, array_contains, lit, struct
# create a dict from df1 with topic as key and list of termWeights+terms as value
d = df1.selectExpr("string(topic)", "arrays_zip(termWeights,terms) as terms").rdd.collectAsMap()
# ignore this if text comparison are case-sensitive, you might do the same to df1 as well
df2 = df2.withColumn('tokens', expr("transform(tokens, x -> lower(x))"))
# save the column names of the original df2
cols = df2.columns
# iterate through all items of d(or df1) and update df2 with new columns from each
# topic with the value a struct containing `sum_of_weights`, `topic` and `has_match`(if any terms is matched)
for x,y in d.items():
df2 = df2.withColumn(x,
struct(
sum([when(array_contains('tokens', t.terms), t.termWeights).otherwise(0) for t in y]).alias('sum_of_weights'),
lit(x).alias('topic'),
coalesce(*[when(array_contains('tokens', t.terms),1) for t in y]).isNotNull().alias('has_match')
)
)
# create a new array containing all new columns(topics), and find array_max
# from items with `has_match == true`, and then retrieve the `topic` field
df_new = df2.selectExpr(
*cols,
f"array_max(filter(array({','.join(map('`{}`'.format,d.keys()))}), x -> x.has_match)).topic as topic"
)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.