簡體   English   中英

PySpark DataFrame - 每個分區的“非空集群”數

[英]PySpark DataFrame - Number of "non-null clusters" per partition

首先,讓我首先提供一個示例數據框以進行說明。 我有一個包含兩列的數據框。 在創建它的代碼下方:

df1_l = [
  (0, 1),
  (0, 2),
  (0, 3),
  (0, 4),
  (0, None),
  (0, None),
  (0, None),
  (0, 801),
  (0, 802),
  (0, 803),
  (0, None),
  (0, None),
  (1, 1),
  (1, 2),
  (1, 3),
  (1, 4),
  (1, None),
  (1, None),
  (1, None),
  (1, 801),
  (1, 802),
  (1, 803),
  (1, None),
  (1, None)
]

df1 = spark.createDataFrame(df1_l, schema = ["id", "val"])
df1.show()

數據框如下所示:

+---+----+
| id| val|
+---+----+
|  0|   1|
|  0|   2|
|  0|   3|
|  0|   4|
|  0|null|
|  0|null|
|  0|null|
|  0| 801|
|  0| 802|
|  0| 803|
|  0|null|
|  0|null|
|  1|   1|
|  1|   2|
|  1|   3|
|  1|   4|
|  1|null|
|  1|null|
|  1|null|
|  1| 801|
+---+----+
  • id是我用來划分多個窗口函數的列。
  • val是一列包含空值和數值的值。

目標:我想使用一個新列來計算每個分區中列val中非空集群的數量,該列為給定集群的所有元素提供相同的整數值。 集群是任何一組具有不同於null的值的連續行(1 個與 null 不同的孤立行也構成一個集群)。

換句話說,所需的輸出將如下(列n_cluster

+---+----+---------+
| id| val|n_cluster|
+---+----+---------+
|  0|   1|        1|
|  0|   2|        1|
|  0|   3|        1|
|  0|   4|        1|
|  0|null|     null|
|  0|null|     null|
|  0|null|     null|
|  0| 801|        2|
|  0| 802|        2|
|  0| 803|        2|
|  0|null|     null|
|  0|null|     null|
|  1|   1|        1|
|  1|   2|        1|
|  1|   3|        1|
|  1|   4|        1|
|  1|null|     null|
|  1|null|     null|
|  1|null|     null|
|  1| 801|        2|
+---+----+---------+

有人可以幫我創建列 n_cluster 嗎? .

注意:以上只是一個玩具示例。 每個分區可以有多個大於 2 的簇。列“n_cols”應按照示例中的說明對它們進行編號。

提前致謝

下面的代碼實現了我的意圖,並使用示例數據調用了該函數:

def cluster_ids(df_data: DataFrame, 
                partition_by: List[str],
                val_column: str,
                ts_column: str) -> DataFrame:

  cumsum_column = "cumsum"
  window_cumsum = (
                    Window.partitionBy(*partition_by)
                          .orderBy(F.asc(ts_column))
                          .rowsBetween(Window.unboundedPreceding, Window.currentRow)
                  )

  window_rows = (
              Window.partitionBy(*partition_by)
                    .orderBy(F.asc(cumsum_column))
              )

  marker = func.when(func.col(val_column).isNull(), 1).otherwise(0)
  cumsum = F.sum(marker).over(window_cumsum)

  df_data_cumsum = df_data.withColumn(cumsum_column, cumsum)
  df_cluster_ids = (df_data_cumsum
                          .filter(func.col(val_column).isNotNull())
                          .select(*partition_by, cumsum_column).dropDuplicates()
                          .withColumn("cluster_id", func.row_number().over(window_rows))
                   )
  
  result = (df_data_cumsum.join(df_cluster_ids,
                               on = [*partition_by, cumsum_column],
                               how = "left")
                          .withColumn("cluster_id", 
                                      func.when(func.col("val").isNotNull(), func.col("cluster_id")))
                          .drop(cumsum_column)
           )

  return result

res = cluster_ids(df_data = df1,
                  partition_by = ["id"],
                  ts_column = "row",
                  val_column = "val")

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM