Rewrite UDF to pandas UDF in PySpark
I have a dataframe:
import pyspark.sql.functions as F

sdf1 = spark.createDataFrame(
    [
        (2022, 1, ["apple", "edible"]),
        (2022, 1, ["edible", "fruit"]),
        (2022, 1, ["orange", "sweet"]),
        (2022, 4, ["flowering ", "plant"]),
        (2022, 3, ["green", "kiwi"]),
        (2022, 3, ["kiwi", "fruit"]),
        (2022, 3, ["fruit", "popular"]),
        (2022, 3, ["yellow", "lemon"]),
    ],
    [
        "year",
        "id",
        "bigram",
    ],
)
sdf1.show(truncate=False)
+----+---+-------------------+
|year|id |bigram |
+----+---+-------------------+
|2022|1 |[apple, edible] |
|2022|1 |[edible, fruit] |
|2022|1 |[orange, sweet] |
|2022|4 |[flowering , plant]|
|2022|3 |[green, kiwi] |
|2022|3 |[kiwi, fruit] |
|2022|3 |[fruit, popular] |
|2022|3 |[yellow, lemon] |
+----+---+-------------------+
And I wrote a function that returns chains of bigrams whose last words match the first words of other bigrams in the group. I apply this function separately to the column.
from networkx import DiGraph, dfs_labeled_edges
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

# Grouping
sdf = (
    sdf1.groupby("year", "id")
    .agg(F.collect_set("bigram").alias("collect_bigramm"))
    .withColumn("size", F.size("collect_bigramm"))
)
data_collect = sdf.collect()
@udf(returnType=ArrayType(StringType()))
def myfunc(lst):
    graph = DiGraph()
    for row in data_collect:
        if row["size"] > 1:
            for i, lst1 in enumerate(lst):
                while i < len(lst) - 1:
                    lst2 = lst[i + 1]
                    if lst1[0] == lst2[1]:
                        graph.add_edge(lst2[0], lst2[1])
                        graph.add_edge(lst1[0], lst1[1])
                    elif lst1[1] == lst2[0]:
                        graph.add_edge(lst1[0], lst1[1])
                        graph.add_edge(lst2[0], lst2[1])
                    i = i + 1
    gen = dfs_labeled_edges(graph)
    lst_tmp = []
    lst_res = []
    f = 0
    for g in list(gen):
        if (g[2] == "forward") and (g[0] != g[1]):
            f = 1
            lst_tmp.append(g[0])
            lst_tmp.append(g[1])
        if g[2] == "nontree":
            continue
        if g[2] == "reverse":
            if f == 1:
                lst_res.append(lst_tmp.copy())
                f = 0
            if g[0] in lst_tmp:
                lst_tmp.remove(g[0])
            if g[1] in lst_tmp:
                lst_tmp.remove(g[1])
    if lst_res != []:
        lst_res = [
            ii for n, ii in enumerate(lst_res[0]) if ii not in lst_res[0][:n]
        ]
    if lst_res == []:
        lst_res = None
    return lst_res
sdf_new = sdf.withColumn("new_col", myfunc(F.col("collect_bigramm")))
sdf_new.show(truncate=False)
Output:
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|year|id |collect_bigramm |size|new_col |
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|2022|4 |[[flowering , plant]] |1 |null |
|2022|1 |[[edible, fruit], [orange, sweet], [apple, edible]] |3 |[apple, edible, fruit] |
|2022|3 |[[yellow, lemon], [green, kiwi], [kiwi, fruit], [fruit, popular]]|4 |[green, kiwi, fruit, popular]|
+----+---+-----------------------------------------------------------------+----+-----------------------------+
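For reference, the traversal labels that the function branches on come from networkx's dfs_labeled_edges, which yields (u, v, direction) triples as the DFS enters an edge ("forward"), meets an already-visited node ("nontree"), and backs out ("reverse"). A minimal standalone illustration on the chain that the id=1 group produces:

```python
from networkx import DiGraph, dfs_labeled_edges

# The chain built from the id=1 bigrams: apple -> edible -> fruit
graph = DiGraph()
graph.add_edge("apple", "edible")
graph.add_edge("edible", "fruit")

for u, v, label in dfs_labeled_edges(graph):
    print(u, v, label)
# apple apple forward
# apple edible forward
# edible fruit forward
# edible fruit reverse
# apple edible reverse
# apple apple reverse
```

The self-loop triples (u == v) mark entering and leaving a DFS root, which is why the function skips forward edges with g[0] == g[1].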
But now I want to use a pandas UDF. I would like to first group by and get the collect_bigramm column inside the function, and thus keep all the columns in the dataframe while also adding a new one: the lst_res array from the function.
from pyspark.sql.types import StructType, StructField, IntegerType

schema2 = StructType(
    [
        StructField("year", IntegerType(), True),
        StructField("id", IntegerType(), True),
        StructField("bigram", ArrayType(StringType(), True), True),
        StructField("new_col", ArrayType(StringType(), True), True),
        StructField("collect_bigramm", ArrayType(ArrayType(StringType(), True), True), True),
    ]
)
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf(schema2, functionType=PandasUDFType.GROUPED_MAP)
def myfunc(df):
    graph = DiGraph()
    for index, row in df.iterrows():
        # Instead of the variable lst, I need to insert the column sdf['collect_bigramm']
        ...
    return df

sdf_new = sdf.groupby(["year", "id"]).apply(myfunc)
You don't want to run groupBy twice (once for sdf1 and once for pandas_udf); that would simply defeat the idea behind pandas_udf of "group a list of records, vectorize it, then send it to a worker". You'd want to do something like this instead:

sdf1.groupby("year", "id").applyInPandas(myfunc, schema2)

Your UDF is now a "pandas UDF", which is literally just a Python function that takes one pandas DataFrame and returns another pandas DataFrame. With that in mind, you can even run the function without Spark. The trick here is just how to shape your dataframe to feed it what you need. Check the running code below; I kept most of your networkx code and just fixed the input and output a little.
def myfunc(pdf):
    pdf = (pdf
        .groupby(['year', 'id'])['bigram']
        .agg(list=list, len=len)  # you might want to fix the list here to set
        .reset_index()
        .rename(columns={
            'list': 'collect_bigram',
            'len': 'size',
        })
    )
    graph = DiGraph()
    if pdf['size'][0] > 1:
        lst = pdf['collect_bigram'][0]
        for i, lst1 in enumerate(lst):
            ...  # same as original code
        if lst_res == []:
            lst_res = None
        pdf['new_col'] = [lst_res]
    else:
        pdf['new_col'] = None
    return pdf
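To make the "you can run it without Spark" point concrete, here is a self-contained sketch that fills in the elided part with the question's original networkx logic and runs myfunc on a plain pandas DataFrame holding the rows of one (year, id) group, exactly the shape Spark would hand to applyInPandas. No SparkSession is involved:

```python
import pandas as pd
from networkx import DiGraph, dfs_labeled_edges

def myfunc(pdf):
    # Collapse the group's raw rows into one row with the bigram list and its size.
    pdf = (pdf
        .groupby(['year', 'id'])['bigram']
        .agg(list=list, len=len)
        .reset_index()
        .rename(columns={'list': 'collect_bigram', 'len': 'size'})
    )
    graph = DiGraph()
    if pdf['size'][0] > 1:
        lst = pdf['collect_bigram'][0]
        # Link bigrams that chain on a shared word (original question logic).
        for i, lst1 in enumerate(lst):
            while i < len(lst) - 1:
                lst2 = lst[i + 1]
                if lst1[0] == lst2[1]:
                    graph.add_edge(lst2[0], lst2[1])
                    graph.add_edge(lst1[0], lst1[1])
                elif lst1[1] == lst2[0]:
                    graph.add_edge(lst1[0], lst1[1])
                    graph.add_edge(lst2[0], lst2[1])
                i = i + 1
        # Extract a chain from the DFS traversal labels.
        lst_tmp, lst_res, f = [], [], 0
        for g in dfs_labeled_edges(graph):
            if g[2] == "forward" and g[0] != g[1]:
                f = 1
                lst_tmp.append(g[0])
                lst_tmp.append(g[1])
            if g[2] == "nontree":
                continue
            if g[2] == "reverse":
                if f == 1:
                    lst_res.append(lst_tmp.copy())
                    f = 0
                if g[0] in lst_tmp:
                    lst_tmp.remove(g[0])
                if g[1] in lst_tmp:
                    lst_tmp.remove(g[1])
        if lst_res != []:
            # Deduplicate while preserving order.
            lst_res = [ii for n, ii in enumerate(lst_res[0]) if ii not in lst_res[0][:n]]
        if lst_res == []:
            lst_res = None
        pdf['new_col'] = [lst_res]
    else:
        pdf['new_col'] = None
    return pdf

# One (year, id) group, as Spark would pass it to applyInPandas:
pdf = pd.DataFrame({
    'year': [2022, 2022, 2022],
    'id': [1, 1, 1],
    'bigram': [['apple', 'edible'], ['edible', 'fruit'], ['orange', 'sweet']],
})
print(myfunc(pdf)['new_col'][0])  # ['apple', 'edible', 'fruit']
```

This reproduces the ['apple', 'edible', 'fruit'] result from the original UDF's output for id=1, which is a quick way to sanity-check the grouped-map function before wiring it into applyInPandas.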