PySpark DataFrame - Number of "non-null clusters" per partition
First, let me provide a sample dataframe for illustration purposes. I have a dataframe with two columns. Below is the code to create it:
df1_l = [
(0, 1),
(0, 2),
(0, 3),
(0, 4),
(0, None),
(0, None),
(0, None),
(0, 801),
(0, 802),
(0, 803),
(0, None),
(0, None),
(1, 1),
(1, 2),
(1, 3),
(1, 4),
(1, None),
(1, None),
(1, None),
(1, 801),
(1, 802),
(1, 803),
(1, None),
(1, None)
]
df1 = spark.createDataFrame(df1_l, schema = ["id", "val"])
df1.show()
The dataframe looks as follows:
+---+----+
| id| val|
+---+----+
| 0| 1|
| 0| 2|
| 0| 3|
| 0| 4|
| 0|null|
| 0|null|
| 0|null|
| 0| 801|
| 0| 802|
| 0| 803|
| 0|null|
| 0|null|
| 1| 1|
| 1| 2|
| 1| 3|
| 1| 4|
| 1|null|
| 1|null|
| 1|null|
| 1| 801|
+---+----+
only showing top 20 rows
id is the column I use to partition several window functions. val is a column of values containing both nulls and numeric values.

Goal: I want to count the number of non-null clusters in the column val within each partition, using a new column that gives the same integer value to all elements of a given cluster. A cluster is any set of consecutive rows with values different from null (a single isolated non-null row also constitutes a cluster).
In other words, the desired output would be the following (column n_cluster):
+---+----+---------+
| id| val|n_cluster|
+---+----+---------+
| 0| 1| 1|
| 0| 2| 1|
| 0| 3| 1|
| 0| 4| 1|
| 0|null| null|
| 0|null| null|
| 0|null| null|
| 0| 801| 2|
| 0| 802| 2|
| 0| 803| 2|
| 0|null| null|
| 0|null| null|
| 1| 1| 1|
| 1| 2| 1|
| 1| 3| 1|
| 1| 4| 1|
| 1|null| null|
| 1|null| null|
| 1|null| null|
| 1| 801| 2|
+---+----+---------+
Could somebody help me create the column n_cluster?

NOTE: the above is just a toy example. Each partition can have more than two clusters; the column n_cluster shall number them as clarified in the example.
Thanks in advance.
Below is a function that achieves what I intended, together with a call to it on the sample data. The idea: a running sum over a null marker stays constant across every non-null cluster and increases at each null gap, so the distinct cumsum values of the non-null rows can be numbered with row_number to obtain consecutive cluster ids:
from typing import List

from pyspark.sql import DataFrame, Window
import pyspark.sql.functions as F

def cluster_ids(df_data: DataFrame,
                partition_by: List[str],
                val_column: str,
                ts_column: str) -> DataFrame:
    cumsum_column = "cumsum"
    window_cumsum = (
        Window.partitionBy(*partition_by)
              .orderBy(F.asc(ts_column))
              .rowsBetween(Window.unboundedPreceding, Window.currentRow)
    )
    window_rows = (
        Window.partitionBy(*partition_by)
              .orderBy(F.asc(cumsum_column))
    )
    # Running count of nulls seen so far: every row of a given non-null
    # cluster shares the same count, and each null gap increments it.
    marker = F.when(F.col(val_column).isNull(), 1).otherwise(0)
    cumsum = F.sum(marker).over(window_cumsum)
    df_data_cumsum = df_data.withColumn(cumsum_column, cumsum)
    # Keep one row per (partition, cumsum) value among the non-null rows;
    # row_number then yields the consecutive cluster ids 1, 2, ...
    df_cluster_ids = (df_data_cumsum
                      .filter(F.col(val_column).isNotNull())
                      .select(*partition_by, cumsum_column).dropDuplicates()
                      .withColumn("cluster_id",
                                  F.row_number().over(window_rows)))
    # Join the ids back; null rows keep a null cluster_id.
    result = (df_data_cumsum
              .join(df_cluster_ids,
                    on=[*partition_by, cumsum_column],
                    how="left")
              .withColumn("cluster_id",
                          F.when(F.col(val_column).isNotNull(),
                                 F.col("cluster_id")))
              .drop(cumsum_column))
    return result
# The sample frame has no explicit ordering column, so add one first;
# for this small, locally created dataset monotonically_increasing_id
# preserves the insertion order.
df1 = df1.withColumn("row", F.monotonically_increasing_id())

res = cluster_ids(df_data=df1,
                  partition_by=["id"],
                  ts_column="row",
                  val_column="val")