簡體   English   中英

計算一系列馬爾可夫鏈值

[英]Calculate a sequence of Markov chain values

我有一個 Spark 問題,所以對於每個實體k的輸入,我有一個概率序列p_i和一個關聯的值v_i ,例如數據可以看起來像這樣

entity | Probability | value
A      | 0.8         | 10
A      | 0.6         | 15
A      | 0.3         | 20
B      | 0.8         | 10

然后,對於實體A ,我期望平均值為0.8*10 + (1-0.8)*0.6*15 + (1-0.8)*(1-0.6)*0.3*20 + (1-0.8)*(1-0.6)*(1-0.3)*MAX_VALUE_DEFINED

我如何使用DataFrame agg func在 Spark 中實現此目的? 鑒於groupBy實體和計算結果序列的復雜性,我發現它具有挑戰性。

您可以使用 UDF 來執行此類自定義計算。 這個想法是使用collect_listA所有概率和值分組到一個地方,以便您可以循環遍歷它。 但是, collect_list不遵守記錄的順序,因此可能會導致計算錯誤。 修復它的一種方法是使用monotonically_increasing_id為每一行生成 ID

import pyspark.sql.functions as F

@F.pandas_udf('double')
def markov_udf(values):
    def markov(lst):
        # you can implement your markov logic here
        s = 0
        for i, prob, val in lst:
            s += prob
        return s
    return values.apply(markov)
    
(df
    .withColumn('id', F.monotonically_increasing_id())
    .groupBy('entity')
    .agg(F.array_sort(F.collect_list(F.array('id', 'probability', 'value'))).alias('values'))
    .withColumn('markov', markov_udf('values'))
    .show(10, False)
)

+------+------------------------------------------------------+------+
|entity|values                                                |markov|
+------+------------------------------------------------------+------+
|B     |[[3.0, 0.8, 10.0]]                                    |0.8   |
|A     |[[0.0, 0.8, 10.0], [1.0, 0.6, 15.0], [2.0, 0.3, 20.0]]|1.7   |
+------+------------------------------------------------------+------+

可能有更好的解決方案,但我認為這可以滿足您的需求。

from pyspark.sql import functions as F, Window as W
df = spark.createDataFrame(
    [('A', 0.8, 10),
     ('A', 0.6, 15),
     ('A', 0.3, 20),
     ('B', 0.8, 10)],
    ['entity', 'Probability', 'value']
)

w_desc = W.partitionBy('entity').orderBy(F.desc('value'))
w_asc = W.partitionBy('entity').orderBy('value')
df = df.withColumn('_ent_max_val', F.max('value').over(w_desc))
df = df.withColumn('_prob2', 1 - F.col('Probability'))
df = df.withColumn('_cum_prob2', F.product('_prob2').over(w_asc) / F.col('_prob2'))
df = (df.groupBy('entity')
        .agg(F.round((F.max('_ent_max_val') * F.product('_prob2')
                     + F.sum(F.col('_cum_prob2') * F.col('Probability') * F.col('value'))
             ),2).alias('mean_value'))
)
df.show()
# +------+----------+
# |entity|mean_value|
# +------+----------+
# |     A|      11.4|
# |     B|      10.0|
# +------+----------+

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM