PySpark: How to apply UDF to multiple columns to create multiple new columns?
I have a DataFrame with several columns whose values I want to use as inputs to a function that produces multiple outputs per row, each output going into a new column.
For example, I have a function that takes an address and parses it into finer-grained parts:
def parser(address1: str, city: str, state: str) -> Dict[str, str]:
...
Example output:
{'STREETNUMPREFIX': None,
'STREETNUMBER': '123',
'STREETNUMSUFFIX': None,
'STREETNAME': 'Elm',
'STREETTYPE': 'Ave.'}
So, assuming I have a DataFrame with address1, city, and state columns, I want to apply the parser function above to every row, using the values of those three columns as inputs, and store each row's outputs as new columns matching the returned dictionary.
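In plain Python, the desired per-row transformation can be sketched like this (with a toy regex parser standing in for usaddress, and illustrative values; the real tagging logic is more involved):

```python
import re

def toy_parser(address1: str, city: str, state: str) -> dict:
    # Hypothetical stand-in for the real usaddress-based parser:
    # split "123 Elm Ave." into number / name / type.
    m = re.match(r"(\d+)\s+(\w+)\s+(\w+\.?)", address1)
    return {
        "STREETNUMPREFIX": None,
        "STREETNUMBER": m.group(1) if m else None,
        "STREETNUMSUFFIX": None,
        "STREETNAME": m.group(2) if m else None,
        "STREETTYPE": m.group(3) if m else None,
    }

row = {"Address1": "123 Elm Ave.", "CITY": "Cleveland", "STATE": "OH"}
parsed = toy_parser(row["Address1"], row["CITY"], row["STATE"])
new_row = {**row, **parsed}  # original columns plus one new column per parsed part
```

Note that missing parts come back as None, which matters later for the schema.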
Here is what I have tried so far:
from typing import Dict, List

from pyspark.sql import functions as F
from pyspark.sql.types import Row, StringType, StructField, StructType

import usaddress


def parser(address1: str, city: str, state: str) -> Dict[str, str]:
    unstructured_address = " ".join((address1, city, state))
    return parse_unstructured_address(unstructured_address)


def parse_unstructured_address(address: str) -> Dict[str, str]:
    tags = usaddress.tag(address_string=address)
    return {
        "STREETNUMPREFIX": tags[0].get("AddressNumberPrefix", None),
        "STREETNUMBER": tags[0].get("AddressNumber", None),
        "STREETNUMUNIT": tags[0].get("OccupancyIdentifier", None),
        "STREETNUMSUFFIX": tags[0].get("AddressNumberSuffix", None),
        "PREDIRECTIONAL": tags[0].get("StreetNamePreDirectional", None),
        "STREETNAME": tags[0].get("StreetName", None),
        "STREETTYPE": tags[0].get("StreetNamePostType", None),
        "POSTDIRECTIONAL": tags[0].get("StreetNamePostDirectional", None),
    }


def parse_func(address: str, city: str, state: str) -> Row:
    address_parts = parser(address1=address, city=city, state=state)
    return Row(*address_parts.keys())(*address_parts.values())


def get_schema(columns: List[str]) -> StructType:
    return StructType([StructField(col_name, StringType(), False) for col_name in columns])


input_columns = ["Address1", "CITY", "STATE"]
df = spark.createDataFrame([("123 Main St.", "Cleveland", "OH"), ("57 Heinz St.", "Columbus", "OH")], input_columns)

parsed_columns = ["STREETNUMPREFIX", "STREETNUMBER", "STREETNUMSUFFIX", "STREETNAME", "STREETTYPE"]
out_columns = input_columns + parsed_columns
output_schema = get_schema(out_columns)

parse_udf = F.udf(parse_func, output_schema)
df = df.withColumn("Output", F.explode(F.array(parse_udf(df["Address1"], df["CITY"], df["STATE"]))))
display(df)
The above only produces an opaque NullPointerException that gives me no hint about how to fix the problem:
SparkException: Job aborted due to stage failure: Task 0 in stage 9.0 failed 4 times, most recent failure: Lost task 0.3 in stage 9.0 (TID 86, 172.18.237.92, executor 1): java.lang.NullPointerException
at org.apache.spark.sql.catalyst.expressions.codegen.UnsafeWriter.write(UnsafeWriter.java:110)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.writeFields_0_1$(Unknown Source)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$SpecificUnsafeProjection.apply(Unknown Source)
at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$11(EvalPythonExec.scala:134)
at scala.collection.Iterator$$anon$10.next(Iterator.scala:459)
at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage3.processNext(Unknown Source)
at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
at org.apache.spark.sql.execution.WholeStageCodegenExec$$anon$1.hasNext(WholeStageCodegenExec.scala:731)
at org.apache.spark.sql.execution.collect.UnsafeRowBatchUtils$.encodeUnsafeRows(UnsafeRowBatchUtils.scala:80)
at org.apache.spark.sql.execution.collect.Collector.$anonfun$processFunc$1(Collector.scala:187)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.doRunTask(Task.scala:144)
at org.apache.spark.scheduler.Task.run(Task.scala:117)
at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$9(Executor.scala:657)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1581)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:660)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Here is a very simple example of how I use a udf.
from typing import List

from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, StringType


def cal(a: int, b: int) -> List[int]:
    return [a + b, a * b]


cal = udf(cal, ArrayType(StringType()))

df.select('A', 'B', *[cal('A', 'B')[i] for i in range(0, 2)]) \
    .toDF('A', 'B', 'Add', 'Multiple').show()

+---+---+---+--------+
|  A|  B|Add|Multiple|
+---+---+---+--------+
|  1|  2|  3|       2|
|  2|  4|  6|       8|
|  3|  6|  9|      18|
+---+---+---+--------+
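The column-indexing trick above can be read in pure Python terms: the udf returns a list per row, and indexing it once per element pulls each one out into its own column. A Spark-free analogue (for intuition only, using the same toy data as the table above):

```python
def cal(a: int, b: int) -> list:
    # Same function as in the udf example: one call, two outputs.
    return [a + b, a * b]

rows = [(1, 2), (2, 4), (3, 6)]

# Mimics df.select('A', 'B', cal('A', 'B')[0], cal('A', 'B')[1])
out = [(a, b, cal(a, b)[0], cal(a, b)[1]) for a, b in rows]
# out == [(1, 2, 3, 2), (2, 4, 6, 8), (3, 6, 9, 18)]
```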
I checked your code and found this:

def get_schema(columns: [str]) -> StructType:
    return StructType([StructField(col_name, StringType(), False) for col_name in columns])

You are not allowing null values in any of the columns, but nulls do occur in your data, and that is where the error comes from. I suggest you change the nullable flag from False -> True, and then it will work.
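To see why the non-nullable schema blows up, here is a Spark-free sketch of the check that nullable controls (a hypothetical validator, not Spark's actual code): a field declared with nullable=False rejects None, and the parser legitimately returns None for missing address parts such as STREETNUMPREFIX.

```python
def validate_row(values, schema):
    # schema mimics [StructField(name, StringType(), nullable), ...]
    # as simple (name, nullable) pairs.
    for value, (name, nullable) in zip(values, schema):
        if value is None and not nullable:
            raise ValueError(f"field {name} is non-nullable but got None")
    return values

strict = [("STREETNUMPREFIX", False), ("STREETNUMBER", False)]  # like False in get_schema
relaxed = [("STREETNUMPREFIX", True), ("STREETNUMBER", True)]   # the suggested fix

row = (None, "123")            # prefix is missing, exactly as the parser returns
validate_row(row, relaxed)     # passes once nullable is True
# validate_row(row, strict)    # would raise, analogous to the NullPointerException above
```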