
PySpark sampleBy using multiple columns

I want to do stratified sampling on a dataframe in PySpark. There is a sampleBy(col, fractions, seed=None) function, but it seems to only use one column as the strata. Is there any way to use multiple columns as the strata?
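For reference, a minimal sketch of the existing single-column usage (the dataframe, fractions, and seed here are made-up illustrations, not from my actual data):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 10), (1, 20), (2, 30)], ["ID", "X"])
#sampleBy stratifies on exactly one column: each distinct key maps to its sampling fraction
sampled = df.sampleBy("ID", fractions={1: 0.5, 2: 0.5}, seed=7)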

Based on the answer here

After converting it to Python, I think the answer could look something like this:

#create a dataframe to use
df = sc.parallelize([ (1,1234,282),(1,1396,179),(2,8620,178),(3,1620,191),(3,8820,828) ] ).toDF(["ID","X","Y"])

#we are going to use the first two columns as our key (strata)
#assign sampling percentages to each key # you could do something cooler here
fractions = df.rdd.map(lambda x: (x[0],x[1])).distinct().map(lambda x: (x,0.3)).collectAsMap()
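#fractions is now e.g. {(1, 1234): 0.3, (1, 1396): 0.3, (2, 8620): 0.3, (3, 1620): 0.3, (3, 8820): 0.3}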

#setup how we want to key the dataframe
kb = df.rdd.keyBy(lambda x: (x[0],x[1]))

#create a dataframe after sampling from our newly keyed rdd
#note, if the sample did not return any values you'll get a `ValueError: RDD is empty` error

sampleddf = kb.sampleByKey(False,fractions).map(lambda x: x[1]).toDF(df.columns)
sampleddf.show()
+---+----+---+
| ID|   X|  Y|
+---+----+---+
|  1|1234|282|
|  1|1396|179|
|  3|1620|191|
+---+----+---+
#other examples
kb.sampleByKey(False,fractions).map(lambda x: x[1]).toDF(df.columns).show()
+---+----+---+
| ID|   X|  Y|
+---+----+---+
|  2|8620|178|
+---+----+---+


kb.sampleByKey(False,fractions).map(lambda x: x[1]).toDF(df.columns).show()
+---+----+---+
| ID|   X|  Y|
+---+----+---+
|  1|1234|282|
|  1|1396|179|
+---+----+---+
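The three runs above return different rows because no seed was passed to sampleByKey. A minimal sketch of a reproducible variant (the seed value 42 is an arbitrary choice, not from the original answer):

kb.sampleByKey(False, fractions, seed=42).map(lambda x: x[1]).toDF(df.columns).show()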

Is this the kind of thing you were looking for?

James Tobin's solution above works for the provided example, but I had difficulty replicating the approach on my own dataset (nearly 2 million records). Strange Java-related runtime errors occurred, and I could not pinpoint the problem (I was running pyspark in local mode).

An alternative approach is to reuse the single-column stratified sampling method. To do so, we create a new (temporary) column that is a merger of the values present in the multiple columns on which we originally wanted to apply stratified sampling. Then we perform the split and drop the merged column from the resulting splits.

from pyspark.sql import functions as F

def get_stratified_split_multiple_columns(input_df, col_name1, col_name2, seed_value=None, train_frac=0.6):
    '''
    Follows the approach of stratified sampling based on a single column as presented at
    https://stackoverflow.com/a/47672336/530399 .
    However, this time our single column is a merger of the values present in the
    multiple columns (`col_name1` and `col_name2`).

    Note that pyspark's sampleBy is not exact. Therefore, if there are too few examples
    per category, it could be that none of them go to the validation/test split,
    resulting in an error.
    '''

    merged_col_name = "both_labels"
    # The "_#_@_#_" acts as a separator between the values of the two columns.
    input_df = input_df.withColumn(merged_col_name, F.concat(F.col(col_name1), F.lit('_#_@_#_'),
                                                             F.col(col_name2)))

    fractions1 = input_df.select(merged_col_name).distinct().withColumn("fraction",
                                                                        F.lit(train_frac)).rdd.collectAsMap()
    train_df = input_df.stat.sampleBy(merged_col_name, fractions1, seed_value)

    valid_and_test_df = input_df.exceptAll(train_df)
    fractions2 = {key: 0.5 for key in fractions1}  # 0.5 for an equal split of valid and test sets

    valid_df = valid_and_test_df.stat.sampleBy(merged_col_name, fractions2, seed_value)
    test_df = valid_and_test_df.exceptAll(valid_df)

    # Drop the temporary merged column from all splits
    train_df = train_df.drop(merged_col_name)
    valid_df = valid_df.drop(merged_col_name)
    test_df = test_df.drop(merged_col_name)

    return train_df, valid_df, test_df
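A hypothetical usage sketch with the toy dataframe from the first answer (the 60/20/20 split and seed are assumptions; note that exceptAll requires Spark 2.4+, and with so few rows per stratum the smaller splits may come back empty):

train_df, valid_df, test_df = get_stratified_split_multiple_columns(df, "ID", "X", seed_value=42, train_frac=0.6)
print(train_df.count(), valid_df.count(), test_df.count())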
