簡體   English   中英

PySpark:用於 scipy 統計轉換的 Pandas UDF

[英]PySpark: Pandas UDF for scipy statistical transformations

我正在嘗試在 Spark 數據幀上創建一列 x 列的標准化(z 分數)列,但由於沒有一個工作正常而缺少一些東西。

這是我的例子:

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType
from scipy.stats import zscore

@pandas_udf('float')
def zscore_udf(x: pd.Series) -> pd.Series:
    return zscore(x)

spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()

columns = ["id","x"]
data = [("a", 81.0),
    ("b", 36.2),
    ("c", 12.0),
    ("d", 81.0),
    ("e", 36.3),
    ("f", 12.0),
    ("g", 111.7)]

df = spark.createDataFrame(data=data,schema=columns)

df.show()

df = df.withColumn('y', zscore_udf(df.x))

df.show()

這導致明顯錯誤的計算:

+---+-----+----+
| id|    x|   y|
+---+-----+----+
|  a| 81.0|null|
|  b| 36.2| 1.0|
|  c| 12.0|-1.0|
|  d| 81.0| 1.0|
|  e| 36.3|-1.0|
|  f| 12.0|-1.0|
|  g|111.7| 1.0|
+---+-----+----+

謝謝您的幫助。

怎么修:
而不是使用 UDF 計算stddev_pop和數據幀的avg並手動計算 z-score。
我建議第一步在整個數據幀上使用“窗口函數”,然后使用簡單的算術來獲得 z 分數。
查看建議的代碼:

from pyspark.sql.functions import avg, col, stddev_pop
from pyspark.sql.window import Window

df2 = df \
.select(
    "*",
    avg("x").over(Window.partitionBy()).alias("avg_x"),
    stddev_pop("x").over(Window.partitionBy()).alias("stddev_x"),
) \
.withColumn("manual_z_score", (col("x") - col("avg_x")) / col("stddev_x")) 

為什么 UDF 不起作用?
Spark用於分布式計算。 當您在 DataFrame 上執行操作時,Spark 會將工作負載分配到可用的 executors/workers 上的分區中。

pandas_udf也不例外。 當從 pd.Series -> pd.Series 類型運行 UDF 時,一些行被發送到分區 X,一些行被發送到分區 Y,然后當zscore運行時,它計算分區中數據的平均值和標准差,並寫入基於 zscore僅在該數據上。

我將使用spark_partition_id來“證明”這一點。
a、b、c 行映射到分區 0 中,而 d、e、f、g 位於分區 1 中。我手動計算了整個數據集和分區數據的均值/stddev_pop,然后計算了 z 分數。 UDF z-score 等於分區的 z-score。

from pyspark.sql.functions import pandas_udf, spark_partition_id, avg, stddev, col, stddev_pop
from pyspark.sql.window import Window

df2 = df \
.select(
    "*",
    zscore_udf(df.x).alias("z_score"),
    spark_partition_id().alias("partition"),
    avg("x").over(Window.partitionBy(spark_partition_id())).alias("avg_partition_x"),
    stddev_pop("x").over(Window.partitionBy(spark_partition_id())).alias("stddev_partition_x"),
) \
.withColumn("partition_z_score", (col("x") - col("avg_partition_x")) / col("stddev_partition_x"))

df2.show()

+---+-----+-----------+---------+-----------------+------------------+--------------------+
| id|    x|    z_score|partition|  avg_partition_x|stddev_partition_x|   partition_z_score|
+---+-----+-----------+---------+-----------------+------------------+--------------------+
|  a| 81.0|   1.327058|        0|43.06666666666666|28.584533502500186|  1.3270579815484989|
|  b| 36.2|-0.24022315|        0|43.06666666666666|28.584533502500186|-0.24022314955974558|
|  c| 12.0| -1.0868348|        0|43.06666666666666|28.584533502500186| -1.0868348319887526|
|  d| 81.0|  0.5366879|        1|            60.25|38.663063768925504|  0.5366879387524718|
|  e| 36.3|-0.61945426|        1|            60.25|38.663063768925504| -0.6194542714757446|
|  f| 12.0| -1.2479612|        1|            60.25|38.663063768925504|  -1.247961110593097|
|  g|111.7|  1.3307275|        1|            60.25|38.663063768925504|  1.3307274433163698|
+---+-----+-----------+---------+-----------------+------------------+--------------------+

我還在計算之前添加了 df.repartition(8) 並設法獲得與原始問題相似的結果。 具有 0 stddev --> null z 分數的分區,具有 2 行的分區 --> (-1, 1) z 分數。

+---+-----+-------+---------+---------------+------------------+-----------------+
| id|    x|z_score|partition|avg_partition_x|stddev_partition_x|partition_z_score|
+---+-----+-------+---------+---------------+------------------+-----------------+
|  a| 81.0|   null|        0|           81.0|               0.0|             null|
|  d| 81.0|   null|        0|           81.0|               0.0|             null|
|  f| 12.0|   null|        1|           12.0|               0.0|             null|
|  b| 36.2|   -1.0|        6|          73.95|             37.75|             -1.0|
|  g|111.7|    1.0|        6|          73.95|             37.75|              1.0|
|  c| 12.0|   -1.0|        7|          24.15|12.149999999999999|             -1.0|
|  e| 36.3|    1.0|        7|          24.15|12.149999999999999|              1.0|
+---+-----+-------+---------+---------------+------------------+-----------------+

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM