簡體   English   中英

確保 PySpark 數組中相鄰元素之間的差異大於給定的最小值

[英]Ensure difference between adjoining elements in PySpark array is more than a given minimum value

我有一個包含三列的 PySpark 數據框( df )。

1. category :一些字符串

2. startTimeArray :它是一個按升序包含時間戳的數組。

3. endTimeArray :它是一個按升序包含時間戳的數組。

在每一行中, startTimeArray中的數組長度與endTimeArray的數組長度相同。 對於這些數組中的每個索引, startTimeArray給出的時間戳比endTimeArray對應的(相同索引)時間戳少(發生在前一個日期)。

startTimeArray列(和endTimeArray列)中,數組的長度可以不同。

以下是數據框的示例:

+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
|category|startTimeArray                                                                                           |endTimeArray                                                                                             |
+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+
|a       |[2019-01-10 00:00:00, 2019-01-12 00:00:00, 2019-01-16 00:00:00, 2019-01-20 00:00:00]                     |[2019-01-11 00:00:00, 2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-01-22 00:00:00]                     |
|a       |[2019-03-11 00:00:00, 2019-03-18 00:00:00, 2019-03-20 00:00:00, 2019-03-25 00:00:00, 2019-03-27 00:00:00]|[2019-03-16 00:00:00, 2019-03-19 00:00:00, 2019-03-23 00:00:00, 2019-03-26 00:00:00, 2019-03-30 00:00:00]|
|b       |[2019-01-14 00:00:00, 2019-01-16 00:00:00, 2019-02-22 00:00:00]                                          |[2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-02-25 00:00:00]                                          |
+--------+---------------------------------------------------------------------------------------------------------+---------------------------------------------------------------------------------------------------------+

在每一行的startTimeArray列中,我想確保數組中連續元素(連續索引處的元素)之間的差異至少為三天。 如果startTimeArray的一行有n元素,我願意刪除數組中的條目,但第一個條目除外。 此外,如果從startTimeArray的一行中刪除索引 i 處的元素,我希望從endTimeArray的同一行中刪除索引 i-1 處的元素。**

如何使用 PySpark 完成此任務?

有幾點,我們需要注意:

  1. 如果startTimeArray的數組有一個元素,我們就讓它在那里。

  2. 我意識到可以通過刪除startTimeArray中數組中第一個元素之后的所有元素來實現此任務。 那將是微不足道的情況。 但我想通過盡可能少的刪除來完成任務。

以下是我在上面給出的示例數據幀df的情況下想要的輸出。

+--------+---------------------------------------------------------------+---------------------------------------------------------------+
|category|startTimeArray                                                 |endTimeArray                                                   |
+--------+---------------------------------------------------------------+---------------------------------------------------------------+
|a       |[2019-01-10 00:00:00, 2019-01-16 00:00:00, 2019-01-20 00:00:00]|[2019-01-15 00:00:00, 2019-01-18 00:00:00, 2019-01-22 00:00:00]|
|a       |[2019-03-11 00:00:00, 2019-03-18 00:00:00, 2019-03-25 00:00:00]|[2019-03-16 00:00:00, 2019-03-23 00:00:00, 2019-03-30 00:00:00]|
|b       |[2019-01-14 00:00:00, 2019-02-22 00:00:00]                     |[2019-01-18 00:00:00, 2019-02-25 00:00:00]                     |
+--------+---------------------------------------------------------------+---------------------------------------------------------------+

用戶定義函數 (UDF) 可以完成這項工作。 雖然它比原生 Spark sql 函數帶來了性能損失,但它清楚地表達了所需的操作。

from datetime import date, timedelta

from pyspark.sql.functions import *
from pyspark.sql.types import *

d = [date(2019, 1, d) for d in (10, 12, 16, 20)]
e = [date(2019, 1, d) for d in (11, 15, 18, 22)]
f = [date(2019, 3, d) for d in (11, 18, 20, 25, 27)]
g = [date(2019, 3, d) for d in (16, 19, 23, 26, 30)]
h = [date(2019, 1, 14), date(2019, 1, 16), date(2019, 2, 22)]
i = [date(2019, 1, 15), date(2019, 1, 18), date(2019, 2, 25)]

df = spark.createDataFrame((("a", d, e), ("a", f, g), ("b", h, i)),
                           schema=("category", "startDates", "endDates"))


@udf(returnType=ArrayType(ArrayType(DateType())))
def retain_dates_n_days_apart(startDates, endDates, min_apart=3):
    start_dates = [startDates[0]]
    end_dates = []
    for start, end in zip(startDates[1:], endDates):
        if start >= start_dates[-1] + timedelta(days=min_apart):
            start_dates.append(start)
            end_dates.append(end)
    end_dates.append(endDates[-1])
    return start_dates, end_dates


df2 = (df
       .withColumn("foo",
                   retain_dates_n_days_apart(df.startDates,
                                             df.endDates))
       .cache())

(df2.withColumn("startDates", df2.foo.getItem(0))
 .withColumn("endDates", df2.foo.getItem(1))
 .drop("foo")
 ).show(truncate=False)
# +--------+------------------------------------+------------------------------------+
# |category|startDates                          |endDates                            |
# +--------+------------------------------------+------------------------------------+
# |a       |[2019-01-10, 2019-01-16, 2019-01-20]|[2019-01-15, 2019-01-18, 2019-01-22]|
# |a       |[2019-03-11, 2019-03-18, 2019-03-25]|[2019-03-16, 2019-03-23, 2019-03-30]|
# |b       |[2019-01-14, 2019-02-22]            |[2019-01-18, 2019-02-25]            |
# +--------+------------------------------------+------------------------------------+

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM