
How to count the trailing zeroes in an array column in a PySpark dataframe without a UDF

I have a DataFrame with a column containing arrays with a fixed number of integers. How can I add a column to the DataFrame that contains the number of trailing zeroes in each array? I would like to avoid using a UDF for better performance.

For example, an input df:

>>> df.show()
+------------+
|           A|
+------------+
| [1,0,1,0,0]|
| [2,3,4,5,6]|
| [0,0,0,0,0]|
| [1,2,3,4,0]|
+------------+

And a wanted output:

>>> trailing_zeroes(df).show()
+------------+-----------------+
|           A|   trailingZeroes|
+------------+-----------------+
| [1,0,1,0,0]|                2|
| [2,3,4,5,6]|                0|
| [0,0,0,0,0]|                5|
| [1,2,3,4,0]|                1|
+------------+-----------------+
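For reference, a straightforward UDF that computes this (the kind of thing I would like to avoid for performance reasons) would look roughly like the sketch below; the function name is just illustrative:

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# Hypothetical UDF baseline, shown only for comparison with the answers below.
@udf(IntegerType())
def count_trailing_zeroes(arr):
    # Walk the array from the end and count zeros until the first non-zero value.
    count = 0
    for value in reversed(arr):
        if value != 0:
            break
        count += 1
    return count

df.withColumn("trailingZeroes", count_trailing_zeroes("A")).show()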

When you convert the array to a string, there are several new ways to get the result:

>>> from pyspark.sql.functions import length, regexp_extract, array_join, reverse
>>> 
>>> df = spark.createDataFrame([(1, [1, 2, 3]),
...                             (2, [2, 0]),
...                             (3, [0, 2, 3, 10]),
...                             (4, [0, 2, 3, 10, 0]),
...                             (5, [0, 1, 0, 0, 0]),
...                             (6, [0, 0, 0]),
...                             (7, [0, ]),
...                             (8, [10, ]),
...                             (9, [100, ]),
...                             (10, [0, 100, ]),
...                             (11, [])],
...                            schema=("id", "arr"))
>>> 
>>> 
>>> df.withColumn("trailing_zero_count",
...               length(regexp_extract(array_join(reverse(df.arr), ""), "^(0+)", 0))
...               ).show()
+---+----------------+-------------------+
| id|             arr|trailing_zero_count|
+---+----------------+-------------------+
|  1|       [1, 2, 3]|                  0|
|  2|          [2, 0]|                  1|
|  3|   [0, 2, 3, 10]|                  0|
|  4|[0, 2, 3, 10, 0]|                  1|
|  5| [0, 1, 0, 0, 0]|                  3|
|  6|       [0, 0, 0]|                  3|
|  7|             [0]|                  1|
|  8|            [10]|                  0|
|  9|           [100]|                  0|
| 10|        [0, 100]|                  0|
| 11|              []|                  0|
+---+----------------+-------------------+

Since Spark 2.4 you can use the higher-order function AGGREGATE to do that:

from pyspark.sql.functions import reverse

(
  df.withColumn("arr_rev", reverse("A"))
  .selectExpr(
    "arr_rev", 
    "AGGREGATE(arr_rev, (1 AS p, CAST(0 AS LONG) AS sum), (buffer, value) -> (if(value != 0, 0, buffer.p), if(value=0, buffer.sum + buffer.p, buffer.sum)), buffer -> buffer.sum) AS result"
  )
)

Here A is assumed to be your array column with numbers. Just be careful with data types: I am casting the initial value to LONG, assuming the numbers inside the array are also longs.
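A slightly simpler variant of the same idea can be sketched as well: keep a single integer accumulator that is incremented for every zero and reset whenever a non-zero value is seen, so no reverse is needed:

from pyspark.sql.functions import expr

# Sketch: the counter resets to 0 on any non-zero element, so the final value
# is the number of trailing zeros. Adjust the initial 0 (e.g. CAST(0 AS LONG))
# if you need a different accumulator type.
df.withColumn(
    "trailingZeroes",
    expr("AGGREGATE(A, 0, (acc, x) -> IF(x = 0, acc + 1, 0))")
).show()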

For Spark 2.4+, you should absolutely use aggregate as shown in @David Vrba's accepted answer.

For older versions, here's an alternative to the regular expression approach.

First create some sample data:

import numpy as np
NROWS = 10
ARRAY_LENGTH = 5
np.random.seed(0)
data = [
    (np.random.randint(0, 100, x).tolist() + [0]*(ARRAY_LENGTH-x),) 
    for x in np.random.randint(0, ARRAY_LENGTH+1, NROWS)
]

df = spark.createDataFrame(data, ["myArray"])
df.show()
#+--------------------+
#|             myArray|
#+--------------------+
#| [36, 87, 70, 88, 0]|
#|[88, 12, 58, 65, 39]|
#|     [0, 0, 0, 0, 0]|
#|  [87, 46, 88, 0, 0]|
#|  [81, 37, 25, 0, 0]|
#|   [77, 72, 9, 0, 0]|
#|    [20, 0, 0, 0, 0]|
#|  [80, 69, 79, 0, 0]|
#|[47, 64, 82, 99, 88]|
#|   [49, 29, 0, 0, 0]|
#+--------------------+

Now iterate over the array indices in reverse: for each position, return ARRAY_LENGTH-(index+1) if the element is non-zero, and null otherwise. Coalescing these expressions returns the value from the first non-null position, which is exactly the number of trailing zeroes. (The expanded form of this expression is sketched after the output below.)

from pyspark.sql.functions import coalesce, col, when, lit
df.withColumn(
    "trailingZeroes",
    coalesce(
        *[
            when(col('myArray').getItem(index) != 0, lit(ARRAY_LENGTH-(index+1)))
            for index in range(ARRAY_LENGTH-1, -1, -1)
        ] + [lit(ARRAY_LENGTH)]
    )
).show()
#+--------------------+--------------+
#|             myArray|trailingZeroes|
#+--------------------+--------------+
#| [36, 87, 70, 88, 0]|             1|
#|[88, 12, 58, 65, 39]|             0|
#|     [0, 0, 0, 0, 0]|             5|
#|  [87, 46, 88, 0, 0]|             2|
#|  [81, 37, 25, 0, 0]|             2|
#|   [77, 72, 9, 0, 0]|             2|
#|    [20, 0, 0, 0, 0]|             4|
#|  [80, 69, 79, 0, 0]|             2|
#|[47, 64, 82, 99, 88]|             0|
#|   [49, 29, 0, 0, 0]|             3|
#+--------------------+--------------+
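To make the generated expression concrete, here is roughly what the list comprehension builds when ARRAY_LENGTH is 3 (a sketch for illustration only):

from pyspark.sql.functions import coalesce, col, when, lit

# Expanded form of the coalesce for ARRAY_LENGTH = 3:
expanded = coalesce(
    when(col("myArray").getItem(2) != 0, lit(0)),  # last element non-zero -> 0 trailing zeros
    when(col("myArray").getItem(1) != 0, lit(1)),
    when(col("myArray").getItem(0) != 0, lit(2)),
    lit(3),                                        # every element was zero
)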

One more solution that works since Spark 1.5.0. Here we use trim, rtrim, regexp_replace and length to count the trailing zeros:

from pyspark.sql.functions import expr

to_string_expr = expr("regexp_replace(trim('[]', string(A)), ', ', '')")

df.withColumn("str_ar", to_string_expr) \
  .withColumn("trailingZeroes", expr("length(str_ar) - length(rtrim('0', str_ar))"))

# +---------------+--------------+
# |              A|trailingZeroes|
# +---------------+--------------+
# |[1, 0, 1, 0, 0]|             2|
# |[2, 3, 4, 5, 6]|             0|
# |[0, 0, 0, 0, 0]|             5|
# |[1, 2, 3, 4, 0]|             1|
# +---------------+--------------+

Analysis:

Starting from the innermost to the outermost elements of expr:

  1. string(A) converts the array to its string representation, i.e. [1, 0, 1, 0, 0].

  2. trim('[]', string(A)) removes the leading [ and the trailing ], i.e. 1, 0, 1, 0, 0.

  3. regexp_replace(trim('[]', string(A)), ', ', '') removes the ', ' separators between the items to form the final string representation, i.e. 10100.

  4. rtrim('0', regexp_replace(trim('[]', string(A)), ', ', '')) trims the trailing zeros, i.e. 101.

  5. Finally, we take the length of the complete string and the length of the trimmed one and subtract them, which gives us the number of trailing zeros.
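A quick way to inspect these intermediate values side by side is sketched below; it just reuses the expressions from above, and the step aliases are purely illustrative:

# Sketch: expose each intermediate step of the expression on the sample data.
df.selectExpr(
    "A",
    "string(A)                                                   AS step1_string",
    "trim('[]', string(A))                                       AS step2_trimmed",
    "regexp_replace(trim('[]', string(A)), ', ', '')             AS step3_digits",
    "rtrim('0', regexp_replace(trim('[]', string(A)), ', ', '')) AS step4_rtrimmed"
).show(truncate=False)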

UPDATE

With the following code you can populate some data (borrowed from @pault's post) and measure the execution time for a large dataset using timeit.

Below I added some benchmarking for three of the posted methods. The results, summarized after the code, show a consistent performance difference between the string-based methods and the aggregate-based one:

from pyspark.sql.functions import expr, regexp_replace, regexp_extract, reverse, length, array_join
import numpy as np
import timeit

NROWS = 1000000
ARRAY_LENGTH = 5
np.random.seed(0)
data = [
    (np.random.randint(0, 9, x).tolist() + [0]*(ARRAY_LENGTH-x),) 
    for x in np.random.randint(0, ARRAY_LENGTH+1, NROWS)
]

df = spark.createDataFrame(data, ["A"])

def trim_func():
  to_string_expr = expr("regexp_replace(trim('[]', string(A)), ', ', '')")
  df.withColumn("str_ar", to_string_expr) \
    .withColumn("trailingZeroes", expr("length(str_ar) - length(rtrim('0', str_ar))")) \
    .show()
# Avg: 0.11089507223994588

def aggr_func():
  df.withColumn("arr_rev", reverse("A")) \
    .selectExpr("arr_rev", "AGGREGATE(arr_rev, (1 AS p, CAST(0 AS LONG) AS sum), \
                (buffer, value) -> (if(value != 0, 0, buffer.p), \
                if(value=0, buffer.sum + buffer.p, buffer.sum)), \
                buffer -> buffer.sum) AS result") \
   .show()
# Avg: 0.16555462517004343

def join_func():
  df.withColumn("trailing_zero_count", \
                length( \
                  regexp_extract( \
                    array_join(reverse(df["A"]), ""), "^(0+)", 0))) \
  .show()
# Avg:0.11372986907997984

rounds = 100
algs = {"trim_func" : trim_func, "aggr_func" : aggr_func, "join_func" : join_func}
report = list()
for k in algs:
  elapsed_time = timeit.timeit(algs[k], number=rounds) / rounds
  report.append((k, elapsed_time))

report_df = spark.createDataFrame(report, ["alg", "avg_time"]).orderBy("avg_time")
display(report_df)


The results showed that, for a dataset of 1,000,000 rows and 100 executions, the average execution time was 25%-30% lower for the string-based methods (trim_func, join_func).

Without being sure about the exact reason, I can only assume that the extra processing time comes from the complexity of the aggregate function itself. In any case, the performance difference seems considerable.

The test was executed on a Databricks Community Edition cluster/notebook.

Using getItem() and a double when().otherwise() for each item in the array, the reduce iterates over the array backwards, building a negative counter of zeros. When the first non-zero value is encountered, the counter becomes positive and the counting stops. The reduce starts the counter with a pseudo count of -1, which is removed at the end.

import pyspark.sql.functions as F
from functools import reduce

# Array elements from last to first.
cols = [F.col('myArray').getItem(index) for index in range(ARRAY_LENGTH-1, -1, -1)]

# Start at -1, decrement while only zeros have been seen, and flip the sign at
# the first non-zero element; abs(...) - 1 removes the pseudo count at the end.
trailing_count_column = F.abs(reduce(
    lambda col1, col2: F.when((col1 < 0) & (col2 != 0), -col1).otherwise(
        F.when((col1 < 0) & (col2 == 0), col1 - 1).otherwise(col1)),
    cols, F.lit(-1))) - 1

df = df.withColumn('trailingZeroes', trailing_count_column)
