I have a DataFrame with a column containing an array of a fixed number of integers. How can I add a column to the df that contains the number of trailing zeroes in each array? I would like to avoid using a UDF for better performance.
For example, an input df:
>>> df.show()
+------------+
| A|
+------------+
| [1,0,1,0,0]|
| [2,3,4,5,6]|
| [0,0,0,0,0]|
| [1,2,3,4,0]|
+------------+
And a wanted output:
>>> trailing_zeroes(df).show()
+------------+-----------------+
| A| trailingZeroes|
+------------+-----------------+
| [1,0,1,0,0]| 2|
| [2,3,4,5,6]| 0|
| [0,0,0,0,0]| 5|
| [1,2,3,4,0]| 1|
+------------+-----------------+
When you convert the array to a string, there are several ways to get to the result:
>>> from pyspark.sql.functions import length, regexp_extract, array_join, reverse
>>>
>>> df = spark.createDataFrame([(1, [1, 2, 3]),
... (2, [2, 0]),
... (3, [0, 2, 3, 10]),
... (4, [0, 2, 3, 10, 0]),
... (5, [0, 1, 0, 0, 0]),
... (6, [0, 0, 0]),
... (7, [0, ]),
... (8, [10, ]),
... (9, [100, ]),
... (10, [0, 100, ]),
... (11, [])],
... schema=("id", "arr"))
>>>
>>>
>>> df.withColumn("trailing_zero_count",
... length(regexp_extract(array_join(reverse(df.arr), ""), "^(0+)", 0))
... ).show()
+---+----------------+-------------------+
| id| arr|trailing_zero_count|
+---+----------------+-------------------+
| 1| [1, 2, 3]| 0|
| 2| [2, 0]| 1|
| 3| [0, 2, 3, 10]| 0|
| 4|[0, 2, 3, 10, 0]| 1|
| 5| [0, 1, 0, 0, 0]| 3|
| 6| [0, 0, 0]| 3|
| 7| [0]| 1|
| 8| [10]| 0|
| 9| [100]| 0|
| 10| [0, 100]| 0|
| 11| []| 0|
+---+----------------+-------------------+
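The per-row logic of this Spark expression can be traced in plain Python, which makes it easier to see why reversing before joining works (trailing_zero_count is a hypothetical helper, not part of the answer's code):

```python
import re

def trailing_zero_count(arr):
    # Mirror of the Spark expression: reverse the array, join to a string,
    # then measure the run of leading zeros with a regex. Reversal turns
    # the trailing zeros into leading characters of the joined string.
    joined = "".join(str(x) for x in reversed(arr))
    match = re.match(r"^(0+)", joined)
    return len(match.group(0)) if match else 0

print(trailing_zero_count([0, 2, 3, 10, 0]))  # 1
print(trailing_zero_count([0, 1, 0, 0, 0]))   # 3
print(trailing_zero_count([]))                # 0
```

Note that multi-digit values are safe here: since no integer's string form starts with the digit 0 (except 0 itself), the leading-zero run can never bleed into a non-zero element such as 10.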
Since Spark 2.4 you can use the higher-order function AGGREGATE to do that:
from pyspark.sql.functions import reverse
(
    df.withColumn("arr_rev", reverse("A"))
    .selectExpr(
        "arr_rev",
        "AGGREGATE(arr_rev, (1 AS p, CAST(0 AS LONG) AS sum), (buffer, value) -> (if(value != 0, 0, buffer.p), if(value = 0, buffer.sum + buffer.p, buffer.sum)), buffer -> buffer.sum) AS result"
    )
)
assuming A is your array with numbers. Just be careful with data types here: I am casting the initial value to LONG, assuming the numbers inside the array are also longs.
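The (p, sum) accumulator used by the AGGREGATE expression can be mimicked in plain Python with functools.reduce to check the logic (an illustrative sketch only; trailing_zeros_fold is a hypothetical name):

```python
from functools import reduce

def trailing_zeros_fold(arr):
    # p stays 1 while we are still inside the trailing run of zeros
    # (the array is reversed first, so trailing zeros come first);
    # sum gains one per zero seen while p == 1, then freezes once
    # the first non-zero value flips p to 0.
    def step(buffer, value):
        p, total = buffer
        return (0 if value != 0 else p,
                total + p if value == 0 else total)
    return reduce(step, reversed(arr), (1, 0))[1]

print(trailing_zeros_fold([1, 0, 1, 0, 0]))  # 2
print(trailing_zeros_fold([0, 0, 0, 0, 0]))  # 5
```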
For Spark 2.4+, you should absolutely use aggregate as shown in @David Vrba's accepted answer. For older versions of Spark, here's an alternative to the regular-expression approach.
First create some sample data:
import numpy as np
NROWS = 10
ARRAY_LENGTH = 5
np.random.seed(0)
data = [
(np.random.randint(0, 100, x).tolist() + [0]*(ARRAY_LENGTH-x),)
for x in np.random.randint(0, ARRAY_LENGTH+1, NROWS)
]
df = spark.createDataFrame(data, ["myArray"])
df.show()
#+--------------------+
#| myArray|
#+--------------------+
#| [36, 87, 70, 88, 0]|
#|[88, 12, 58, 65, 39]|
#| [0, 0, 0, 0, 0]|
#| [87, 46, 88, 0, 0]|
#| [81, 37, 25, 0, 0]|
#| [77, 72, 9, 0, 0]|
#| [20, 0, 0, 0, 0]|
#| [80, 69, 79, 0, 0]|
#|[47, 64, 82, 99, 88]|
#| [49, 29, 0, 0, 0]|
#+--------------------+
Now iterate through your columns in reverse and return null if the value at that index is 0, or ARRAY_LENGTH-(index+1) otherwise. Coalesce the results, which returns the value from the first non-null index - the same as the number of trailing zeros.
from pyspark.sql.functions import coalesce, col, when, lit
df.withColumn(
"trailingZeroes",
coalesce(
*[
when(col('myArray').getItem(index) != 0, lit(ARRAY_LENGTH-(index+1)))
for index in range(ARRAY_LENGTH-1, -1, -1)
] + [lit(ARRAY_LENGTH)]
)
).show()
#+--------------------+--------------+
#| myArray|trailingZeroes|
#+--------------------+--------------+
#| [36, 87, 70, 88, 0]| 1|
#|[88, 12, 58, 65, 39]| 0|
#| [0, 0, 0, 0, 0]| 5|
#| [87, 46, 88, 0, 0]| 2|
#| [81, 37, 25, 0, 0]| 2|
#| [77, 72, 9, 0, 0]| 2|
#| [20, 0, 0, 0, 0]| 4|
#| [80, 69, 79, 0, 0]| 2|
#|[47, 64, 82, 99, 88]| 0|
#| [49, 29, 0, 0, 0]| 3|
#+--------------------+--------------+
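The coalesce-over-reversed-indices idea maps directly onto a plain-Python next() over a generator, which may clarify why the fallback of ARRAY_LENGTH is needed (trailing_zeros_coalesce is a hypothetical helper; arrays are assumed to have exactly ARRAY_LENGTH entries, as in the sample data):

```python
ARRAY_LENGTH = 5

def trailing_zeros_coalesce(arr):
    # Walk indices from last to first; the first non-zero position
    # determines the count, exactly as coalesce() picks the first
    # non-null when() result. If every entry is zero, no when()
    # fires, so the literal ARRAY_LENGTH fallback is returned.
    return next(
        (ARRAY_LENGTH - (index + 1)
         for index in range(ARRAY_LENGTH - 1, -1, -1)
         if arr[index] != 0),
        ARRAY_LENGTH,
    )

print(trailing_zeros_coalesce([36, 87, 70, 88, 0]))  # 1
print(trailing_zeros_coalesce([0, 0, 0, 0, 0]))      # 5
```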
One more solution that works since Spark 1.5.0. Here we use trim, rtrim, regexp_replace and length to count the trailing zeros:
from pyspark.sql.functions import expr
to_string_expr = expr("regexp_replace(trim('[]', string(A)), ', ', '')")
df.withColumn("str_ar", to_string_expr) \
  .withColumn("trailingZeroes", expr("length(str_ar) - length(rtrim('0', str_ar))")) \
  .show()
# +---------------+--------------+
# | A|trailingZeroes|
# +---------------+--------------+
# |[1, 0, 1, 0, 0]| 2|
# |[2, 3, 4, 5, 6]| 0|
# |[0, 0, 0, 0, 0]| 5|
# |[1, 2, 3, 4, 0]| 1|
# +---------------+--------------+
Analysis:
Starting from the innermost to the outermost elements of expr:

string(A) converts the array to its string representation, i.e. [1, 0, 1, 0, 0].

trim('[]', string(A)) removes the leading [ and trailing ], i.e. 1, 0, 1, 0, 0.

regexp_replace(trim('[]', string(A)), ', ', '') removes the ', ' between items to form the final string representation, i.e. 10100.

rtrim('0', regexp_replace(trim('[]', string(A)), ', ', '')) trims the trailing zeros, i.e. 101.

Finally, we take the lengths of the complete string and the trimmed one and subtract them; this gives us the number of trailing zeros.
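These steps can be traced one-to-one in plain Python (trailing_zeros_trim is an illustrative helper, not part of the answer):

```python
def trailing_zeros_trim(arr):
    # Step-by-step mirror of the Spark expression chain.
    s = str(arr)                        # "[1, 0, 1, 0, 0]"  ~ string(A)
    s = s.strip("[]")                   # "1, 0, 1, 0, 0"    ~ trim('[]', ...)
    s = s.replace(", ", "")             # "10100"            ~ regexp_replace(..., ', ', '')
    return len(s) - len(s.rstrip("0"))  # ~ length(...) - length(rtrim('0', ...))

print(trailing_zeros_trim([1, 0, 1, 0, 0]))  # 2
```

One caveat worth noting: because the zeros are stripped digit-by-digit, this approach over-counts when the last non-zero element itself ends in the digit 0 - for example [10, 0] yields 2 instead of 1. It is only safe when the array values are single-digit, as in the benchmark data below.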
UPDATE
With the following code you can populate some data (borrowed from @pault's post) and measure the execution time for a large dataset using timeit.

Below I added some benchmarks for three of the posted methods. From the results we can draw some conclusions about the relative performance of the methods:
from pyspark.sql.functions import expr, regexp_replace, regexp_extract, reverse, length, array_join
import numpy as np
import timeit
NROWS = 1000000
ARRAY_LENGTH = 5
np.random.seed(0)
data = [
(np.random.randint(0, 9, x).tolist() + [0]*(ARRAY_LENGTH-x),)
for x in np.random.randint(0, ARRAY_LENGTH+1, NROWS)
]
df = spark.createDataFrame(data, ["A"])
def trim_func():
to_string_expr = expr("regexp_replace(trim('[]', string(A)), ', ', '')")
df.withColumn("str_ar", to_string_expr) \
.withColumn("trailingZeroes", expr("length(str_ar) - length(rtrim('0', str_ar))")) \
.show()
# Avg: 0.11089507223994588
def aggr_func():
df.withColumn("arr_rev", reverse("A")) \
.selectExpr("arr_rev", "AGGREGATE(arr_rev, (1 AS p, CAST(0 AS LONG) AS sum), \
(buffer, value) -> (if(value != 0, 0, buffer.p), \
if(value=0, buffer.sum + buffer.p, buffer.sum)), \
buffer -> buffer.sum) AS result") \
.show()
# Avg: 0.16555462517004343
def join_func():
df.withColumn("trailing_zero_count", \
length( \
regexp_extract( \
array_join(reverse(df["A"]), ""), "^(0+)", 0))) \
.show()
# Avg:0.11372986907997984
rounds = 100
algs = {"trim_func" : trim_func, "aggr_func" : aggr_func, "join_func" : join_func}
report = list()
for k in algs:
elapsed_time = timeit.timeit(algs[k], number=rounds) / rounds
report.append((k, elapsed_time))
report_df = spark.createDataFrame(report, ["alg", "avg_time"]).orderBy("avg_time")
display(report_df)
The results showed that for a dataset of 1,000,000 rows and 100 executions, the average execution time was 25%-30% lower for the string-based methods (trim_func, join_func).
Without being sure about the exact reason, I can only assume that the extra processing time comes from the complexity of the aggregate function itself. In any case, the performance difference seems considerable.
The test was executed under the databricks community edition cluster/notebook.
Using getItem() and a double when().otherwise() for each item in the array, the reduce iterates over the array backwards and builds a negative count of zeros. When the first non-zero value is encountered, the counter becomes positive and the counting stops. The reduce starts the counter with a pseudo-count of -1, which is removed at the end.
import pyspark.sql.functions as F
from functools import reduce
cols = [F.col('myArray').getItem(index) for index in range(ARRAY_LENGTH - 1, -1, -1)]
trailing_count_column = F.abs(reduce(
    lambda col1, col2: F.when((col1 < 0) & (col2 != 0), -col1).otherwise(
        F.when((col1 < 0) & (col2 == 0), col1 - 1).otherwise(col1)),
    cols,
    F.lit(-1))) - 1
df = df.withColumn('trailingZeroes', trailing_count_column)
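The signed-counter trick can be sanity-checked with a plain-Python equivalent of the reduce (trailing_zeros_signed is an illustrative name, not part of the answer's code):

```python
from functools import reduce

def trailing_zeros_signed(arr):
    # The counter starts at -1; each zero seen while it is still negative
    # decrements it, and the first non-zero value negates it, freezing the
    # count as a positive number. abs(...) - 1 removes the pseudo-count.
    def step(acc, value):
        if acc < 0 and value != 0:
            return -acc
        if acc < 0 and value == 0:
            return acc - 1
        return acc
    return abs(reduce(step, reversed(arr), -1)) - 1

print(trailing_zeros_signed([1, 0, 1, 0, 0]))  # 2
print(trailing_zeros_signed([2, 3, 4, 5, 6]))  # 0
```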