繁体   English   中英

将 UDF 重写为 pandas UDF Pyspark

[英]Rewrite UDF to pandas UDF Pyspark

我有一个 dataframe:

import pyspark.sql.functions as F

sdf1 = spark.createDataFrame(
    [
        (2022, 1, ["apple", "edible"]),
        (2022, 1, ["edible", "fruit"]),
        (2022, 1, ["orange", "sweet"]),
        (2022, 4, ["flowering ", "plant"]),
        (2022, 3, ["green", "kiwi"]),
        (2022, 3, ["kiwi", "fruit"]),
        (2022, 3, ["fruit", "popular"]),
        (2022, 3, ["yellow", "lemon"]),
    ],
    [
        "year",
        "id",
        "bigram",
    ],
)
sdf1.show(truncate=False)

    +----+---+-------------------+
    |year|id |bigram             |
    +----+---+-------------------+
    |2022|1  |[apple, edible]    |
    |2022|1  |[edible, fruit]    |
    |2022|1  |[orange, sweet]    |
    |2022|4  |[flowering , plant]|
    |2022|3  |[green, kiwi]      |
    |2022|3  |[kiwi, fruit]      |
    |2022|3  |[fruit, popular]   |
    |2022|3  |[yellow, lemon]    |
    +----+---+-------------------+

我写了一个 function,它返回 n-grams 中最后一个词相同的二元组。我将这个 function 分别应用于该列。

from networkx import DiGraph, dfs_labeled_edges

# Grouping
sdf = (
    sdf1.groupby("year", "id")
    .agg(F.collect_set("bigram").alias("collect_bigramm"))
    .withColumn("size", F.size("collect_bigramm"))
)

data_collect = sdf.collect()


@udf(returnType=ArrayType(StringType()))
def myfunc(lst):
    graph = DiGraph()

    for row in data_collect:
        if row["size"] > 1:
            for i, lst1 in enumerate(lst):
                while i < len(lst) - 1:
                    lst2 = lst[i + 1]
                    if lst1[0] == lst2[1]:
                        graph.add_edge(lst2[0], lst2[1])
                        graph.add_edge(lst1[0], lst1[1])
                    elif lst1[1] == lst2[0]:
                        graph.add_edge(lst1[0], lst1[1])
                        graph.add_edge(lst2[0], lst2[1])
                    i = i + 1

            gen = dfs_labeled_edges(graph)
            lst_tmp = []
            lst_res = []
            f = 0
            for g in list(gen):
                if (g[2] == "forward") and (g[0] != g[1]):
                    f = 1
                    lst_tmp.append(g[0])
                    lst_tmp.append(g[1])

                if g[2] == "nontree":
                    continue
                if g[2] == "reverse":
                    if f == 1:
                        lst_res.append(lst_tmp.copy())
                    f = 0
                    if g[0] in lst_tmp:
                        lst_tmp.remove(g[0])
                    if g[1] in lst_tmp:
                        lst_tmp.remove(g[1])

            if lst_res != []:
                lst_res = [
                    ii for n, ii in enumerate(lst_res[0]) if ii not in lst_res[0][:n]
                ]
            if lst_res == []:
                lst_res = None
            return lst_res


sdf_new = sdf.withColumn("new_col", myfunc(F.col("collect_bigramm")))
sdf_new.show(truncate=False)

Output:

+----+---+-----------------------------------------------------------------+----+-----------------------------+
|year|id |collect_bigramm                                                          |size|new_col                      |
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|2022|4  |[[flowering , plant]]                                            |1   |null                         |
|2022|1  |[[edible, fruit], [orange, sweet], [apple, edible]]              |3   |[apple, edible, fruit]       |
|2022|3  |[[yellow, lemon], [green, kiwi], [kiwi, fruit], [fruit, popular]]|4   |[green, kiwi, fruit, popular]|
+----+---+-----------------------------------------------------------------+----+-----------------------------+

但现在我想使用 pandas udf。 我想首先groupby并获取function中的collect_bigramm列。因此保留dataframe中的所有列,但还要添加一个新列,即function中的lst_res数组。


schema2 = StructType(
    [
        StructField("year", IntegerType(), True),
        StructField("id", IntegerType(), True),
        StructField("bigram", ArrayType(StringType(), True), True),
        StructField("new_col", ArrayType(StringType(), True), True),
        StructField("collect_bigramm", ArrayType(ArrayType(StringType(), True), True), True),
    ]
)


@pandas_udf(schema2, functionType=PandasUDFType.GROUPED_MAP)
def myfunc(df):

    graph = DiGraph()
    for index, row in df.iterrows():
        # Instead of the variable lst, i need to insert a column sdf['collect_bigramm']
        ...

    return df


sdf_new = sdf.groupby(["year", "id"]).apply(myfunc)
  1. 不想运行groupBy两次(一次用于sdf1 ,一次用于pandas_udf ),它只会扼杀 pandas_udf 的“对记录列表进行分组,然后对其进行矢量化,然后发送给工作人员”的pandas_udf 你想做这样的事情而不是sdf1.groupby("year", "id").applyInPandas(myfunc, schema2)

  2. 你的 UDF 现在是一个“熊猫 UDF”,它实际上只是一个 Python function,取一个 Pandas DF 并返回另一个 Pandas UDF。 有了这个意思,您甚至可以在没有Spark 的情况下运行 function。 这里的诀窍就是如何形成你的 dataframe 来满足你的需要。 检查下面的运行代码,我保留了你的大部分 .networkx 代码,只是从输入和 output 中修复了一点。

def myfunc(pdf):
    pdf = (pdf
        .groupby(['year', 'id'])['bigram']
        .agg(list=list, len=len) # you might want to fix the list here to set
        .reset_index()
        .rename(columns={
            'list': 'collect_bigram',
            'len': 'size',
        })
    )

    graph = DiGraph()
    if pdf['size'][0] > 1:
        lst = pdf['collect_bigram'][0]
        for i, lst1 in enumerate(lst):
        ... # same as original code
        if lst_res == []:
            lst_res = None
        pdf['new_col'] = [lst_res]
    else:
        pdf['new_col'] = None
    return pdf

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM