PySpark: How to apply UDF to multiple columns to create multiple new columns?

I have a DataFrame containing several columns I'd like to use as input to a function which will produce multiple outputs per row, with each output going into a new column.

For example, I have a function that takes address values and parses them into finer-grained parts:

def parser(address1: str, city: str, state: str) -> Dict[str, str]: 
    ...

Example output:

{'STREETNUMPREFIX': None,
 'STREETNUMBER': '123',
 'STREETNUMSUFFIX': None,
 'STREETNAME': 'Elm',
 'STREETTYPE': 'Ave.'}

So let's say I have a DataFrame with columns address1, city, and state, and I would like to apply the above parser function across all rows, using the values of these three columns as input and storing the output for each row as new columns matching the keys of the returned dictionary.

Here is what I have tried so far:

from typing import Dict, List

from pyspark.sql import functions as F
from pyspark.sql.types import Row, StringType, StructField, StructType
import usaddress

def parser(address1: str, city: str, state: str) -> Dict[str, str]:
    unstructured_address = " ".join((address1, city, state))
    return parse_unstructured_address(unstructured_address)
  
  
def parse_unstructured_address(address: str) -> Dict[str, str]:
  
    tags = usaddress.tag(address_string=address)
    return {
      "STREETNUMPREFIX": tags[0].get("AddressNumberPrefix", None),
      "STREETNUMBER": tags[0].get("AddressNumber", None),
      "STREETNUMUNIT": tags[0].get("OccupancyIdentifier", None),
      "STREETNUMSUFFIX": tags[0].get("AddressNumberSuffix", None),
      "PREDIRECTIONAL": tags[0].get("StreetNamePreDirectional", None),
      "STREETNAME": tags[0].get("StreetName", None),
      "STREETTYPE": tags[0].get("StreetNamePostType", None),
      "POSTDIRECTIONAL": tags[0].get("StreetNamePostDirectional", None),
    }

def parse_func(address: str, city: str, state: str) -> Row:
    address_parts = parser(address1=address, city=city, state=state)
    return Row(*address_parts.keys())(*address_parts.values())

def get_schema(columns: List[str]) -> StructType:
    return StructType([StructField(col_name, StringType(), False) for col_name in columns])

input_columns = ["Address1", "CITY", "STATE"]
df = spark.createDataFrame([("123 Main St.", "Cleveland", "OH"), ("57 Heinz St.", "Columbus", "OH")], input_columns)

parsed_columns = ["STREETNUMPREFIX", "STREETNUMBER", "STREETNUMSUFFIX", "STREETNAME", "STREETTYPE"]
out_columns = input_columns + parsed_columns
output_schema = get_schema(out_columns)

parse_udf = F.udf(parse_func, output_schema)

df = df.withColumn("Output", F.explode(F.array(parse_udf(df["Address1"], df["CITY"], df["STATE"]))))
display(df)

The above has only resulted in strange null pointer exceptions that tell me nothing about how to fix things:

SparkException: Job aborted due to stage failure: Task 0 in stage 9.0 failed 4 times, most recent failure: Lost task 0.3 in stage 9.0 (TID 86, 172.18.237.92, executor 1): java.lang.NullPointerException
    at org.apache.spark.sql.catalyst.expressions.codegen.UnsafeWriter.write(UnsafeWriter.java:110)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.writeFields_0_1$(Unknown Source)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
    at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$11(EvalPythonExec.scala:134)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:459)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:731)
    at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:80)
    at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$1(Collector.scala:187)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.doRunTask(Task.scala:144)
    at org.apache.spark.scheduler.Task.run(Task.scala:117)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$9(Executor.scala:657)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1581)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:660)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

Here is my really simple example of udf usage.

from typing import List

from pyspark.sql.functions import *
from pyspark.sql.types import *

# Sample data matching the output shown below.
df = spark.createDataFrame([(1, 2), (2, 4), (3, 6)], ['A', 'B'])

def cal(a: int, b: int) -> List[int]:
    return [a + b, a * b]

cal = udf(cal, ArrayType(StringType()))

# Index into the array returned by the UDF to turn each element into its own column.
df.select('A', 'B', *[cal('A', 'B')[i] for i in range(0, 2)]) \
  .toDF('A', 'B', 'Add', 'Multiple').show()

+---+---+---+--------+
|  A|  B|Add|Multiple|
+---+---+---+--------+
|  1|  2|  3|       2|
|  2|  4|  6|       8|
|  3|  6|  9|      18|
+---+---+---+--------+
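
The same per-element idea can be applied to the address parser from the question by declaring the UDF's return type as a map and pulling out one column per key. The following is only a sketch of that variant (it assumes the parser function and the DataFrame with Address1, CITY, and STATE columns from the question), not code I have run against usaddress:

from pyspark.sql import functions as F
from pyspark.sql.types import MapType, StringType

# Sketch: expose the parser's dict as a MapType column, then index it per key.
parse_map_udf = F.udf(parser, MapType(StringType(), StringType()))

parsed_columns = ["STREETNUMPREFIX", "STREETNUMBER", "STREETNUMSUFFIX", "STREETNAME", "STREETTYPE"]

df = df.withColumn("parsed", parse_map_udf("Address1", "CITY", "STATE"))
df = df.select("Address1", "CITY", "STATE",
               *[F.col("parsed")[key].alias(key) for key in parsed_columns])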

I have checked your code and found this:

def get_schema(columns: [str]) -> StructType:
    return StructType([StructField(col_name, StringType(), False) for col_name in columns])

You did not allow null values for any of the columns, but the parsed data does contain nulls, which is where the error comes from. I'd recommend changing the nullable flag from False to True; then it will work.
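
For reference, here is roughly what that change looks like applied to your code (a sketch I have not run against usaddress): build the schema over the fields the parser actually returns, with nullable=True, and expand the resulting struct column instead of exploding an array.

from typing import List

from pyspark.sql import functions as F
from pyspark.sql.types import StringType, StructField, StructType

def get_schema(columns: List[str]) -> StructType:
    # nullable=True: the parser legitimately returns None for missing address parts
    return StructType([StructField(col_name, StringType(), True) for col_name in columns])

# The field names and order mirror the dict built in parse_unstructured_address.
parsed_columns = ["STREETNUMPREFIX", "STREETNUMBER", "STREETNUMUNIT", "STREETNUMSUFFIX",
                  "PREDIRECTIONAL", "STREETNAME", "STREETTYPE", "POSTDIRECTIONAL"]

parse_udf = F.udf(parse_func, get_schema(parsed_columns))

df = df.withColumn("Output", parse_udf(df["Address1"], df["CITY"], df["STATE"]))
df = df.select("Address1", "CITY", "STATE", "Output.*")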
