Rewrite UDF to pandas UDF in PySpark

I have a dataframe:

import pyspark.sql.functions as F

sdf1 = spark.createDataFrame(
    [
        (2022, 1, ["apple", "edible"]),
        (2022, 1, ["edible", "fruit"]),
        (2022, 1, ["orange", "sweet"]),
        (2022, 4, ["flowering ", "plant"]),
        (2022, 3, ["green", "kiwi"]),
        (2022, 3, ["kiwi", "fruit"]),
        (2022, 3, ["fruit", "popular"]),
        (2022, 3, ["yellow", "lemon"]),
    ],
    [
        "year",
        "id",
        "bigram",
    ],
)
sdf1.show(truncate=False)

    +----+---+-------------------+
    |year|id |bigram             |
    +----+---+-------------------+
    |2022|1  |[apple, edible]    |
    |2022|1  |[edible, fruit]    |
    |2022|1  |[orange, sweet]    |
    |2022|4  |[flowering , plant]|
    |2022|3  |[green, kiwi]      |
    |2022|3  |[kiwi, fruit]      |
    |2022|3  |[fruit, popular]   |
    |2022|3  |[yellow, lemon]    |
    +----+---+-------------------+

I wrote a function that chains bigrams sharing a word, that is, where the last word of one bigram is the first word of another. I apply this function to the grouped column.

from networkx import DiGraph, dfs_labeled_edges
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

# Grouping
sdf = (
    sdf1.groupby("year", "id")
    .agg(F.collect_set("bigram").alias("collect_bigramm"))
    .withColumn("size", F.size("collect_bigramm"))
)

data_collect = sdf.collect()


@udf(returnType=ArrayType(StringType()))
def myfunc(lst):
    graph = DiGraph()

    for row in data_collect:
        if row["size"] > 1:
            for i, lst1 in enumerate(lst):
                while i < len(lst) - 1:
                    lst2 = lst[i + 1]
                    if lst1[0] == lst2[1]:
                        graph.add_edge(lst2[0], lst2[1])
                        graph.add_edge(lst1[0], lst1[1])
                    elif lst1[1] == lst2[0]:
                        graph.add_edge(lst1[0], lst1[1])
                        graph.add_edge(lst2[0], lst2[1])
                    i = i + 1

            gen = dfs_labeled_edges(graph)
            lst_tmp = []
            lst_res = []
            f = 0
            for g in list(gen):
                if (g[2] == "forward") and (g[0] != g[1]):
                    f = 1
                    lst_tmp.append(g[0])
                    lst_tmp.append(g[1])

                if g[2] == "nontree":
                    continue
                if g[2] == "reverse":
                    if f == 1:
                        lst_res.append(lst_tmp.copy())
                    f = 0
                    if g[0] in lst_tmp:
                        lst_tmp.remove(g[0])
                    if g[1] in lst_tmp:
                        lst_tmp.remove(g[1])

            if lst_res != []:
                lst_res = [
                    ii for n, ii in enumerate(lst_res[0]) if ii not in lst_res[0][:n]
                ]
            if lst_res == []:
                lst_res = None
            return lst_res


sdf_new = sdf.withColumn("new_col", myfunc(F.col("collect_bigramm")))
sdf_new.show(truncate=False)

Output:

+----+---+-----------------------------------------------------------------+----+-----------------------------+
|year|id |collect_bigramm                                                  |size|new_col                      |
+----+---+-----------------------------------------------------------------+----+-----------------------------+
|2022|4  |[[flowering , plant]]                                            |1   |null                         |
|2022|1  |[[edible, fruit], [orange, sweet], [apple, edible]]              |3   |[apple, edible, fruit]       |
|2022|3  |[[yellow, lemon], [green, kiwi], [kiwi, fruit], [fruit, popular]]|4   |[green, kiwi, fruit, popular]|
+----+---+-----------------------------------------------------------------+----+-----------------------------+

But now I want to use a pandas UDF. I would like to group first and get the collect_bigramm column inside the function. That way all the columns stay in the dataframe and a new one is added, holding the lst_res array computed in the function.


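from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import (
    ArrayType, IntegerType, StringType, StructField, StructType,
)

schema2 = StructType(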
    [
        StructField("year", IntegerType(), True),
        StructField("id", IntegerType(), True),
        StructField("bigram", ArrayType(StringType(), True), True),
        StructField("new_col", ArrayType(StringType(), True), True),
        StructField("collect_bigramm", ArrayType(ArrayType(StringType(), True), True), True),
    ]
)


@pandas_udf(schema2, functionType=PandasUDFType.GROUPED_MAP)
def myfunc(df):

    graph = DiGraph()
    for index, row in df.iterrows():
        # Instead of the variable lst, I need to use the column sdf['collect_bigramm'] here
        ...

    return df


sdf_new = sdf.groupby(["year", "id"]).apply(myfunc)

  1. You don't want to run groupBy twice (once for sdf1 and once for the pandas UDF). That defeats the point of a pandas UDF, which is to group a set of records, vectorize it, and send it to a worker. You'd want to do something like this instead: sdf1.groupby("year", "id").applyInPandas(myfunc, schema2)

  2. Your UDF is now a pandas UDF, which is literally just a Python function that takes one pandas DataFrame and returns another pandas DataFrame. That means you can even run the function without Spark. The trick is just how to shape your dataframe so it feeds the function what it needs. Check the running code below; I kept most of your networkx code and only adjusted the input and output a little.

def myfunc(pdf):
    pdf = (pdf
        .groupby(['year', 'id'])['bigram']
        .agg(list=list, len=len)  # keeps duplicates; deduplicate here if you need collect_set semantics
        .reset_index()
        .rename(columns={
            'list': 'collect_bigram',
            'len': 'size',
        })
    )

    graph = DiGraph()
    if pdf['size'][0] > 1:
        lst = pdf['collect_bigram'][0]
        for i, lst1 in enumerate(lst):
            ...  # same as original code
        if lst_res == []:
            lst_res = None
        pdf['new_col'] = [lst_res]
    else:
        pdf['new_col'] = None
    return pdf
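
Since the snippet above elides the graph logic, and the function is plain pandas, here is a minimal self-contained sketch that can be checked locally before handing it to Spark. The names chain_bigrams, add_chain, RESULT_SCHEMA and apply_with_spark are hypothetical, not from the original post. RESULT_SCHEMA is also an assumption: unlike schema2 from the question, it lists exactly the columns the function returns (no bigram, plus size) and uses long for year and id, because createDataFrame infers Python ints as long. The Spark call itself is only wrapped in a function and is not exercised here.

```python
import pandas as pd
from networkx import DiGraph, dfs_labeled_edges


def chain_bigrams(bigrams):
    """Chain bigrams that share a boundary word (same logic as the question's UDF)."""
    graph = DiGraph()
    for i, first in enumerate(bigrams):
        for second in bigrams[i + 1:]:
            # Same edge insertion order as the question's code.
            if first[0] == second[1]:
                graph.add_edge(second[0], second[1])
                graph.add_edge(first[0], first[1])
            elif first[1] == second[0]:
                graph.add_edge(first[0], first[1])
                graph.add_edge(second[0], second[1])

    path, chains, moved_forward = [], [], False
    for head, tail, label in dfs_labeled_edges(graph):
        if label == "forward" and head != tail:
            moved_forward = True
            path.extend([head, tail])
        elif label == "reverse":
            if moved_forward:
                chains.append(path.copy())
            moved_forward = False
            for word in (head, tail):
                if word in path:
                    path.remove(word)
    if not chains:
        return None
    # Like the question's code: keep only the first chain, dropping repeated words in order.
    return list(dict.fromkeys(chains[0]))


def add_chain(pdf):
    """Receive the rows of one (year, id) group and return one summary row."""
    bigrams = [list(b) for b in pdf["bigram"]]  # plain list, duplicates are kept
    return pd.DataFrame({
        "year": [pdf["year"].iloc[0]],
        "id": [pdf["id"].iloc[0]],
        "collect_bigramm": [bigrams],
        "size": [len(bigrams)],
        "new_col": [chain_bigrams(bigrams)],
    })


# Assumed schema: must list exactly the columns that add_chain returns.
RESULT_SCHEMA = (
    "year long, id long, collect_bigramm array<array<string>>, "
    "size int, new_col array<string>"
)


def apply_with_spark(sdf1):
    """Spark wiring; sdf1 is the Spark DataFrame from the question."""
    return sdf1.groupby("year", "id").applyInPandas(add_chain, schema=RESULT_SCHEMA)


# Local check without Spark, using the question's data.
rows = [
    (2022, 1, ["apple", "edible"]),
    (2022, 1, ["edible", "fruit"]),
    (2022, 1, ["orange", "sweet"]),
    (2022, 4, ["flowering ", "plant"]),
    (2022, 3, ["green", "kiwi"]),
    (2022, 3, ["kiwi", "fruit"]),
    (2022, 3, ["fruit", "popular"]),
    (2022, 3, ["yellow", "lemon"]),
]
pdf_all = pd.DataFrame(rows, columns=["year", "id", "bigram"])

chains = {}
for (year, group_id), group in pdf_all.groupby(["year", "id"]):
    chains[(int(year), int(group_id))] = add_chain(group)["new_col"].iloc[0]
print(chains)
```

With Spark available, apply_with_spark(sdf1).show(truncate=False) should give one row per (year, id) group. Building the summary row directly from the group avoids a second groupby inside the function, since applyInPandas already passes in one group at a time.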
