簡體   English   中英

如何在沒有 UDF 的情況下計算 PySpark 數據幀中數組列中的尾隨零

[英]How to count the trailing zeroes in an array column in a PySpark dataframe without a UDF

我有一個數據框,其中有一列具有固定數量整數的數組。 如何將包含數組中尾隨零數的列添加到 df 中? 我想避免使用 UDF 以獲得更好的性能。

例如,輸入 df:

>>> df.show()
+------------+
|           A|
+------------+
| [1,0,1,0,0]|
| [2,3,4,5,6]|
| [0,0,0,0,0]|
| [1,2,3,4,0]|
+------------+

和一個想要的輸出:

>>> trailing_zeroes(df).show()
+------------+-----------------+
|           A|   trailingZeroes|
+------------+-----------------+
| [1,0,1,0,0]|                2|
| [2,3,4,5,6]|                0|
| [0,0,0,0,0]|                5|
| [1,2,3,4,0]|                1|
+------------+-----------------+

將數組轉換為字符串時,有幾種新方法可以獲得結果:

>>> from pyspark.sql.functions import length, regexp_extract, array_join, reverse
>>> 
>>> df = spark.createDataFrame([(1, [1, 2, 3]),
...                             (2, [2, 0]),
...                             (3, [0, 2, 3, 10]),
...                             (4, [0, 2, 3, 10, 0]),
...                             (5, [0, 1, 0, 0, 0]),
...                             (6, [0, 0, 0]),
...                             (7, [0, ]),
...                             (8, [10, ]),
...                             (9, [100, ]),
...                             (10, [0, 100, ]),
...                             (11, [])],
...                            schema=("id", "arr"))
>>> 
>>> 
>>> df.withColumn("trailing_zero_count",
...               length(regexp_extract(array_join(reverse(df.arr), ""), "^(0+)", 0))
...               ).show()
+---+----------------+-------------------+
| id|             arr|trailing_zero_count|
+---+----------------+-------------------+
|  1|       [1, 2, 3]|                  0|
|  2|          [2, 0]|                  1|
|  3|   [0, 2, 3, 10]|                  0|
|  4|[0, 2, 3, 10, 0]|                  1|
|  5| [0, 1, 0, 0, 0]|                  3|
|  6|       [0, 0, 0]|                  3|
|  7|             [0]|                  1|
|  8|            [10]|                  0|
|  9|           [100]|                  0|
| 10|        [0, 100]|                  0|
| 11|              []|                  0|
+---+----------------+-------------------+

從 Spark 2.4 開始,您可以使用高階函數AGGREGATE來做到這一點:

from pyspark.sql.functions import reverse

(
  df.withColumn("arr_rev", reverse("A"))
  .selectExpr(
    "arr_rev", 
    "AGGREGATE(arr_rev, (1 AS p, CAST(0 AS LONG) AS sum), (buffer, value) -> (if(value != 0, 0, buffer.p), if(value=0, buffer.sum + buffer.p, buffer.sum)), buffer -> buffer.sum) AS result"
  )
)

假設A是您的數字數組。 這里只是要小心數據類型。 假設數組內的數字也是 long,我將初始值轉換為LONG

對於 Spark 2.4+,您絕對應該使用aggregate ,如@David Vrba的已接受答案所示

對於較舊的模型,這是正則表達式方法的替代方法。

首先創建一些示例數據:

import numpy as np
NROWS = 10
ARRAY_LENGTH = 5
np.random.seed(0)
data = [
    (np.random.randint(0, 100, x).tolist() + [0]*(ARRAY_LENGTH-x),) 
    for x in np.random.randint(0, ARRAY_LENGTH+1, NROWS)
]

df = spark.createDataFrame(data, ["myArray"])
df.show()
#+--------------------+
#|             myArray|
#+--------------------+
#| [36, 87, 70, 88, 0]|
#|[88, 12, 58, 65, 39]|
#|     [0, 0, 0, 0, 0]|
#|  [87, 46, 88, 0, 0]|
#|  [81, 37, 25, 0, 0]|
#|   [77, 72, 9, 0, 0]|
#|    [20, 0, 0, 0, 0]|
#|  [80, 69, 79, 0, 0]|
#|[47, 64, 82, 99, 88]|
#|   [49, 29, 0, 0, 0]|
#+--------------------+

現在反向迭代您的列,如果列是0則返回null否則返回ARRAY_LENGTH-(index+1) 合並此結果,這將返回第一個非空索引的值 - 與尾隨 0 的數量相同。

from pyspark.sql.functions import coalesce, col, when, lit, 
df.withColumn(
    "trailingZeroes",
    coalesce(
        *[
            when(col('myArray').getItem(index) != 0, lit(ARRAY_LENGTH-(index+1)))
            for index in range(ARRAY_LENGTH-1, -1, -1)
        ] + [lit(ARRAY_LENGTH)]
    )
).show()
#+--------------------+--------------+
#|             myArray|trailingZeroes|
#+--------------------+--------------+
#| [36, 87, 70, 88, 0]|             1|
#|[88, 12, 58, 65, 39]|             0|
#|     [0, 0, 0, 0, 0]|             5|
#|  [87, 46, 88, 0, 0]|             2|
#|  [81, 37, 25, 0, 0]|             2|
#|   [77, 72, 9, 0, 0]|             2|
#|    [20, 0, 0, 0, 0]|             4|
#|  [80, 69, 79, 0, 0]|             2|
#|[47, 64, 82, 99, 88]|             0|
#|   [49, 29, 0, 0, 0]|             3|
#+--------------------+--------------+

另一種自 Spark 1.5.0 起有效的解決方案。 這里我們使用trimrtrimregexp_replacelength來計算尾隨零:

from pyspark.sql.functions import expr

to_string_expr = expr("regexp_replace(trim('[]', string(A)), ', ', '')")

df.withColumn("str_ar", to_string_expr) \
  .withColumn("trailingZeroes", expr("length(str_ar) - length(rtrim('0', str_ar))"))

# +---------------+--------------+
# |              A|trailingZeroes|
# +---------------+--------------+
# |[1, 0, 1, 0, 0]|             2|
# |[2, 3, 4, 5, 6]|             0|
# |[0, 0, 0, 0, 0]|             5|
# |[1, 2, 3, 4, 0]|             1|
# +---------------+--------------+

分析:

expr最內到最外層元素開始:

  1. string(A)將數組轉換為其字符串表示,即[1, 0, 1, 0, 0]

  2. trim('[]', string(A))刪除前導[和尾隨]1, 0, 1, 0, 0

  3. regexp_replace(trim('[]', string(A)), ', ', '')中移除了,項目之間以形成即,最終字符串表示10100

  4. rtrim('0',regexp_replace(trim('[]', string(A)), ', ', ''))修剪尾隨零,即: 101

  5. 最后,我們得到完整字符串和修剪后的字符串的長度,然后減去它們,這將給我們零尾隨長度。

更新

使用下一個代碼,您可以填充一些數據(從@pault 的帖子中借用)並使用timeit測量大型數據集的執行時間。

下面我為三個已發布的方法添加了一些基准測試。 從結果中我們可以得出結論,這些方法的性能存在一些趨勢:

from pyspark.sql.functions import expr, regexp_replace, regexp_extract, reverse, length, array_join
import numpy as np
import timeit

NROWS = 1000000
ARRAY_LENGTH = 5
np.random.seed(0)
data = [
    (np.random.randint(0, 9, x).tolist() + [0]*(ARRAY_LENGTH-x),) 
    for x in np.random.randint(0, ARRAY_LENGTH+1, NROWS)
]

df = spark.createDataFrame(data, ["A"])

def trim_func():
  to_string_expr = expr("regexp_replace(trim('[]', string(A)), ', ', '')")
  df.withColumn("str_ar", to_string_expr) \
    .withColumn("trailingZeroes", expr("length(str_ar) - length(rtrim('0', str_ar))")) \
    .show()
# Avg: 0.11089507223994588

def aggr_func():
  df.withColumn("arr_rev", reverse("A")) \
    .selectExpr("arr_rev", "AGGREGATE(arr_rev, (1 AS p, CAST(0 AS LONG) AS sum), \
                (buffer, value) -> (if(value != 0, 0, buffer.p), \
                if(value=0, buffer.sum + buffer.p, buffer.sum)), \
                buffer -> buffer.sum) AS result") \
   .show()
# Avg: 0.16555462517004343

def join_func():
  df.withColumn("trailing_zero_count", \
                length( \
                  regexp_extract( \
                    array_join(reverse(df["A"]), ""), "^(0+)", 0))) \
  .show()
# Avg:0.11372986907997984

rounds = 100
algs = {"trim_func" : trim_func, "aggr_func" : aggr_func, "join_func" : join_func}
report = list()
for k in algs:
  elapsed_time = timeit.timeit(algs[k], number=rounds) / rounds
  report.append((k, elapsed_time))

report_df = spark.createDataFrame(report, ["alg", "avg_time"]).orderBy("avg_time")
display(report_df)

在此處輸入圖片說明

結果表明,對於 1000000 行和 100 次執行的數據集,基於字符串的處理 (trim_func, join_func) 方法的平均執行時間降低25%-30%

在不確定確切原因的情況下,我可以假設額外的處理時間來自聚合函數本身的復雜性。 無論如何,似乎性能差異是相當可觀的。

測試在 databricks 社區版集群/筆記本下執行。

對數組中的每個項目使用getItem()和 double when().otherwise() ,reduce 向后迭代數組,構建一個零的負計數器。 當遇到第一個非零值時,計數器變為正數並停止計數。 reduce 以 -1 的偽計數開始計數器,最后將其刪除。

import pyspark.sql.functions as F
from functools import reduce

cols = [F.col('myArray').getItem(index) for index in range(ARRAY_LENGTH-1, -1, -1)]
trailing_count_column = F.abs(reduce(lambda col1, col2: F.when((col1 < 0) & (col2 != 0), -col1).othewise(
    F.when((col1 < 0) & (col2 == 0), col1 - 1).otherwise(col1)), cols, F.lit(-1))) - 1

df = df.withColumn('trailingZeroes', trailing_count_column)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM