
PySpark: How to apply UDF to multiple columns to create multiple new columns?

I have a DataFrame with multiple columns that I want to use as inputs to a function that produces multiple outputs per row, with each output going into its own new column.

For example, I have a function that takes an address value and parses it into more granular parts:

def parser(address1: str, city: str, state: str) -> Dict[str, str]: 
    ...

Sample output:

{'STREETNUMPREFIX': None,
 'STREETNUMBER': '123',
 'STREETNUMSUFFIX': None,
 'STREETNAME': 'Elm',
 'STREETTYPE': 'Ave.'}

So, suppose I have a DataFrame with address1, city, and state columns. I want to apply the parser function above to all rows, using the values of those three columns as inputs, and store each row's output as new columns matching the keys of the returned dictionary.

Here is what I have tried so far:

from typing import Dict, List

from pyspark.sql import functions as F
from pyspark.sql.types import Row, StringType, StructField, StructType
import usaddress

def parser(address1: str, city: str, state: str) -> Dict[str, str]:
    unstructured_address = " ".join((address1, city, state))
    return parse_unstructured_address(unstructured_address)
  
  
def parse_unstructured_address(address: str) -> Dict[str, str]:
    tags = usaddress.tag(address_string=address)
    return {
        "STREETNUMPREFIX": tags[0].get("AddressNumberPrefix", None),
        "STREETNUMBER": tags[0].get("AddressNumber", None),
        "STREETNUMUNIT": tags[0].get("OccupancyIdentifier", None),
        "STREETNUMSUFFIX": tags[0].get("AddressNumberSuffix", None),
        "PREDIRECTIONAL": tags[0].get("StreetNamePreDirectional", None),
        "STREETNAME": tags[0].get("StreetName", None),
        "STREETTYPE": tags[0].get("StreetNamePostType", None),
        "POSTDIRECTIONAL": tags[0].get("StreetNamePostDirectional", None),
    }

def parse_func(address: str, city: str, state: str) -> Row:
    address_parts = parser(address1=address, city=city, state=state)
    # Build a Row class from the dict keys, then instantiate it with the values.
    return Row(*address_parts.keys())(*address_parts.values())

def get_schema(columns: List[str]) -> StructType:
    return StructType([StructField(col_name, StringType(), False) for col_name in columns])

input_columns = ["Address1", "CITY", "STATE"]
df = spark.createDataFrame([("123 Main St.", "Cleveland", "OH"), ("57 Heinz St.", "Columbus", "OH")], input_columns)

parsed_columns = ["STREETNUMPREFIX", "STREETNUMBER", "STREETNUMSUFFIX", "STREETNAME", "STREETTYPE"]
out_columns = input_columns + parsed_columns
output_schema = get_schema(out_columns)

parse_udf = F.udf(parse_func, output_schema)

df = df.withColumn("Output", F.explode(F.array(parse_udf(df["Address1"], df["CITY"], df["STATE"]))))
display(df)

The above only results in a strange null pointer exception that tells me nothing about how to fix the problem:

SparkException: Job aborted due to stage failure: Task 0 in stage 9.0 failed 4 times, most recent failure: Lost task 0.3 in stage 9.0 (TID 86, 172.18.237.92, executor 1): java.lang.NullPointerException
    at org.apache.spark.sql.catalyst.expressions.codegen.UnsafeWriter.write(UnsafeWriter.java:110)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.writeFields_0_1$(Unknown Source)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
    at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$11(EvalPythonExec.scala:134)
    at scala.collection.Iterator$$anon$10.next(Iterator.scala:459)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:731)
    at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:80)
    at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$1(Collector.scala:187)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
    at org.apache.spark.scheduler.Task.doRunTask(Task.scala:144)
    at org.apache.spark.scheduler.Task.run(Task.scala:117)
    at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$9(Executor.scala:657)
    at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1581)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:660)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)

Here is my very simple example of using a udf.

from typing import List

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType

def cal(a: int, b: int) -> List[int]:
    return [a + b, a * b]

cal = udf(cal, ArrayType(StringType()))

df.select('A', 'B', *[cal('A', 'B')[i] for i in range(0, 2)]) \
  .toDF('A', 'B', 'Add', 'Multiple').show()

+---+---+---+--------+
|  A|  B|Add|Multiple|
+---+---+---+--------+
|  1|  2|  3|       2|
|  2|  4|  6|       8|
|  3|  6|  9|      18|
+---+---+---+--------+
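
For comparison, here is a minimal sketch of the same example written to return a named struct instead of an array, so each output keeps its own field name and type and can be expanded with a `.*` selection. The cal_struct name and its schema are mine, not part of the original answer:

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType, StructField, StructType

# Schema for the two outputs; nullable=True avoids the NullPointerException
# discussed below if the UDF ever returns None for a field.
cal_schema = StructType([
    StructField("Add", IntegerType(), True),
    StructField("Multiple", IntegerType(), True),
])

@F.udf(returnType=cal_schema)
def cal_struct(a: int, b: int):
    # A tuple maps positionally onto the struct fields.
    return (a + b, a * b)

# "tmp.*" expands the struct into top-level columns.
df.withColumn("tmp", cal_struct("A", "B")).select("A", "B", "tmp.*").show()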

I checked your code and found this:

def get_schema(columns: List[str]) -> StructType:
    return StructType([StructField(col_name, StringType(), False) for col_name in columns])

You do not allow null values in any of the columns, but nulls do exist, and that is where the error comes from. I suggest you change the nullable flag from False to True, and then it will work.
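
This is not from the answer itself, but a minimal sketch of that suggested fix applied to the question's code, reusing the imports and functions defined there. Besides flipping nullable to True, it lists in the schema exactly the fields parse_func returns, and unpacks the struct with Output.* instead of explode/array:

def get_schema(columns: List[str]) -> StructType:
    # nullable=True: the parser legitimately returns None for missing parts.
    return StructType([StructField(col_name, StringType(), True) for col_name in columns])

# These must match the keys of the dict parse_unstructured_address returns, in order.
parsed_columns = ["STREETNUMPREFIX", "STREETNUMBER", "STREETNUMUNIT", "STREETNUMSUFFIX",
                  "PREDIRECTIONAL", "STREETNAME", "STREETTYPE", "POSTDIRECTIONAL"]
parse_udf = F.udf(parse_func, get_schema(parsed_columns))

# Attach the struct, then expand its fields into top-level columns.
df = (
    df.withColumn("Output", parse_udf(df["Address1"], df["CITY"], df["STATE"]))
      .select(*input_columns, "Output.*")
)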
