[英]How to count the trailing zeroes in an array column in a PySpark dataframe without a UDF
我有一个数据框,其中有一列具有固定数量整数的数组。 如何将包含数组中尾随零数的列添加到 df 中? 我想避免使用 UDF 以获得更好的性能。
例如,输入 df:
>>> df.show()
+------------+
| A|
+------------+
| [1,0,1,0,0]|
| [2,3,4,5,6]|
| [0,0,0,0,0]|
| [1,2,3,4,0]|
+------------+
和一个想要的输出:
>>> trailing_zeroes(df).show()
+------------+-----------------+
| A| trailingZeroes|
+------------+-----------------+
| [1,0,1,0,0]| 2|
| [2,3,4,5,6]| 0|
| [0,0,0,0,0]| 5|
| [1,2,3,4,0]| 1|
+------------+-----------------+
将数组转换为字符串时,有几种新方法可以获得结果:
>>> from pyspark.sql.functions import length, regexp_extract, array_join, reverse
>>>
>>> df = spark.createDataFrame([(1, [1, 2, 3]),
... (2, [2, 0]),
... (3, [0, 2, 3, 10]),
... (4, [0, 2, 3, 10, 0]),
... (5, [0, 1, 0, 0, 0]),
... (6, [0, 0, 0]),
... (7, [0, ]),
... (8, [10, ]),
... (9, [100, ]),
... (10, [0, 100, ]),
... (11, [])],
... schema=("id", "arr"))
>>>
>>>
>>> df.withColumn("trailing_zero_count",
... length(regexp_extract(array_join(reverse(df.arr), ""), "^(0+)", 0))
... ).show()
+---+----------------+-------------------+
| id| arr|trailing_zero_count|
+---+----------------+-------------------+
| 1| [1, 2, 3]| 0|
| 2| [2, 0]| 1|
| 3| [0, 2, 3, 10]| 0|
| 4|[0, 2, 3, 10, 0]| 1|
| 5| [0, 1, 0, 0, 0]| 3|
| 6| [0, 0, 0]| 3|
| 7| [0]| 1|
| 8| [10]| 0|
| 9| [100]| 0|
| 10| [0, 100]| 0|
| 11| []| 0|
+---+----------------+-------------------+
从 Spark 2.4 开始,您可以使用高阶函数AGGREGATE
来做到这一点:
from pyspark.sql.functions import reverse
(
df.withColumn("arr_rev", reverse("A"))
.selectExpr(
"arr_rev",
"AGGREGATE(arr_rev, (1 AS p, CAST(0 AS LONG) AS sum), (buffer, value) -> (if(value != 0, 0, buffer.p), if(value=0, buffer.sum + buffer.p, buffer.sum)), buffer -> buffer.sum) AS result"
)
)
假设A
是您的数字数组。 这里只是要小心数据类型。 假设数组内的数字也是 long,我将初始值转换为LONG
。
对于 Spark 2.4+,您绝对应该使用aggregate
,如@David Vrba的已接受答案所示。
对于较旧的模型,这是正则表达式方法的替代方法。
首先创建一些示例数据:
import numpy as np
NROWS = 10
ARRAY_LENGTH = 5
np.random.seed(0)
data = [
(np.random.randint(0, 100, x).tolist() + [0]*(ARRAY_LENGTH-x),)
for x in np.random.randint(0, ARRAY_LENGTH+1, NROWS)
]
df = spark.createDataFrame(data, ["myArray"])
df.show()
#+--------------------+
#| myArray|
#+--------------------+
#| [36, 87, 70, 88, 0]|
#|[88, 12, 58, 65, 39]|
#| [0, 0, 0, 0, 0]|
#| [87, 46, 88, 0, 0]|
#| [81, 37, 25, 0, 0]|
#| [77, 72, 9, 0, 0]|
#| [20, 0, 0, 0, 0]|
#| [80, 69, 79, 0, 0]|
#|[47, 64, 82, 99, 88]|
#| [49, 29, 0, 0, 0]|
#+--------------------+
现在反向迭代您的列,如果列是0
则返回null
否则返回ARRAY_LENGTH-(index+1)
。 合并此结果,这将返回第一个非空索引的值 - 与尾随 0 的数量相同。
from pyspark.sql.functions import coalesce, col, when, lit,
df.withColumn(
"trailingZeroes",
coalesce(
*[
when(col('myArray').getItem(index) != 0, lit(ARRAY_LENGTH-(index+1)))
for index in range(ARRAY_LENGTH-1, -1, -1)
] + [lit(ARRAY_LENGTH)]
)
).show()
#+--------------------+--------------+
#| myArray|trailingZeroes|
#+--------------------+--------------+
#| [36, 87, 70, 88, 0]| 1|
#|[88, 12, 58, 65, 39]| 0|
#| [0, 0, 0, 0, 0]| 5|
#| [87, 46, 88, 0, 0]| 2|
#| [81, 37, 25, 0, 0]| 2|
#| [77, 72, 9, 0, 0]| 2|
#| [20, 0, 0, 0, 0]| 4|
#| [80, 69, 79, 0, 0]| 2|
#|[47, 64, 82, 99, 88]| 0|
#| [49, 29, 0, 0, 0]| 3|
#+--------------------+--------------+
另一种自 Spark 1.5.0 起有效的解决方案。 这里我们使用trim
、 rtrim
、 regexp_replace
和length
来计算尾随零:
from pyspark.sql.functions import expr
to_string_expr = expr("regexp_replace(trim('[]', string(A)), ', ', '')")
df.withColumn("str_ar", to_string_expr) \
.withColumn("trailingZeroes", expr("length(str_ar) - length(rtrim('0', str_ar))"))
# +---------------+--------------+
# | A|trailingZeroes|
# +---------------+--------------+
# |[1, 0, 1, 0, 0]| 2|
# |[2, 3, 4, 5, 6]| 0|
# |[0, 0, 0, 0, 0]| 5|
# |[1, 2, 3, 4, 0]| 1|
# +---------------+--------------+
分析:
从expr
最内到最外层元素开始:
string(A)
将数组转换为其字符串表示,即[1, 0, 1, 0, 0]
。
trim('[]', string(A))
删除前导[
和尾随]
即1, 0, 1, 0, 0
。
regexp_replace(trim('[]', string(A)), ', ', '')
中移除了,
项目之间以形成即,最终字符串表示10100
。
rtrim('0',regexp_replace(trim('[]', string(A)), ', ', ''))
修剪尾随零,即: 101
。
最后,我们得到完整字符串和修剪后的字符串的长度,然后减去它们,这将给我们零尾随长度。
更新
使用下一个代码,您可以填充一些数据(从@pault 的帖子中借用)并使用timeit
测量大型数据集的执行时间。
下面我为三个已发布的方法添加了一些基准测试。 从结果中我们可以得出结论,这些方法的性能存在一些趋势:
from pyspark.sql.functions import expr, regexp_replace, regexp_extract, reverse, length, array_join
import numpy as np
import timeit
NROWS = 1000000
ARRAY_LENGTH = 5
np.random.seed(0)
data = [
(np.random.randint(0, 9, x).tolist() + [0]*(ARRAY_LENGTH-x),)
for x in np.random.randint(0, ARRAY_LENGTH+1, NROWS)
]
df = spark.createDataFrame(data, ["A"])
def trim_func():
to_string_expr = expr("regexp_replace(trim('[]', string(A)), ', ', '')")
df.withColumn("str_ar", to_string_expr) \
.withColumn("trailingZeroes", expr("length(str_ar) - length(rtrim('0', str_ar))")) \
.show()
# Avg: 0.11089507223994588
def aggr_func():
df.withColumn("arr_rev", reverse("A")) \
.selectExpr("arr_rev", "AGGREGATE(arr_rev, (1 AS p, CAST(0 AS LONG) AS sum), \
(buffer, value) -> (if(value != 0, 0, buffer.p), \
if(value=0, buffer.sum + buffer.p, buffer.sum)), \
buffer -> buffer.sum) AS result") \
.show()
# Avg: 0.16555462517004343
def join_func():
df.withColumn("trailing_zero_count", \
length( \
regexp_extract( \
array_join(reverse(df["A"]), ""), "^(0+)", 0))) \
.show()
# Avg:0.11372986907997984
rounds = 100
algs = {"trim_func" : trim_func, "aggr_func" : aggr_func, "join_func" : join_func}
report = list()
for k in algs:
elapsed_time = timeit.timeit(algs[k], number=rounds) / rounds
report.append((k, elapsed_time))
report_df = spark.createDataFrame(report, ["alg", "avg_time"]).orderBy("avg_time")
display(report_df)
结果表明,对于 1000000 行和 100 次执行的数据集,基于字符串的处理 (trim_func, join_func) 方法的平均执行时间降低了25%-30% 。
在不确定确切原因的情况下,我可以假设额外的处理时间来自聚合函数本身的复杂性。 无论如何,似乎性能差异是相当可观的。
测试在 databricks 社区版集群/笔记本下执行。
对数组中的每个项目使用getItem()
和 double when().otherwise()
,reduce 向后迭代数组,构建一个零的负计数器。 当遇到第一个非零值时,计数器变为正数并停止计数。 reduce 以 -1 的伪计数开始计数器,最后将其删除。
import pyspark.sql.functions as F
from functools import reduce
cols = [F.col('myArray').getItem(index) for index in range(ARRAY_LENGTH-1, -1, -1)]
trailing_count_column = F.abs(reduce(lambda col1, col2: F.when((col1 < 0) & (col2 != 0), -col1).othewise(
F.when((col1 < 0) & (col2 == 0), col1 - 1).otherwise(col1)), cols, F.lit(-1))) - 1
df = df.withColumn('trailingZeroes', trailing_count_column)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.