PySpark sampleBy using multiple columns
I would like to draw a stratified sample from a DataFrame in PySpark. There is a sampleBy(col, fractions, seed=None)
function, but it seems to use only a single column as the strata. Is there a way to use multiple columns as the strata?
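For context, a minimal sketch of the single-column behaviour the question describes; the DataFrame and the fractions are illustrative, and sampleBy stratifies on the one column passed as col:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [(1, 1234, 282), (1, 1396, 179), (2, 8620, 178), (3, 1620, 191), (3, 8820, 828)],
    ["ID", "X", "Y"])
# each distinct value of the single "ID" column is a stratum, sampled at the
# fraction given for it; strata missing from the dict are treated as fraction 0
df.sampleBy("ID", fractions={1: 0.5, 2: 0.5, 3: 0.5}, seed=7).show()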
Based on the answer here, after converting it to Python, I think the answer might look something like this:
# create a DataFrame to use
df = sc.parallelize([(1, 1234, 282), (1, 1396, 179), (2, 8620, 178), (3, 1620, 191), (3, 8820, 828)]).toDF(["ID", "X", "Y"])
# we are going to use the first two columns as our key (strata)
# assign a sampling percentage to each key (you could do something cooler here)
fractions = df.rdd.map(lambda x: (x[0], x[1])).distinct().map(lambda x: (x, 0.3)).collectAsMap()
# set up how we want to key the DataFrame
kb = df.rdd.keyBy(lambda x: (x[0], x[1]))
# create a DataFrame after sampling from our newly keyed RDD
# note: if the sample did not return any values, you'll get a `ValueError: RDD is empty` error
sampleddf = kb.sampleByKey(False, fractions).map(lambda x: x[1]).toDF(df.columns)
sampleddf.show()
+---+----+---+
| ID| X| Y|
+---+----+---+
| 1|1234|282|
| 1|1396|179|
| 3|1620|191|
+---+----+---+
# other examples
kb.sampleByKey(False,fractions).map(lambda x: x[1]).toDF(df.columns).show()
+---+----+---+
| ID| X| Y|
+---+----+---+
| 2|8620|178|
+---+----+---+
kb.sampleByKey(False,fractions).map(lambda x: x[1]).toDF(df.columns).show()
+---+----+---+
| ID| X| Y|
+---+----+---+
| 1|1234|282|
| 1|1396|179|
+---+----+---+
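Note that sampleByKey draws a fresh random sample on each call, which is why the repeated runs above differ. It also accepts an optional seed argument if you need a reproducible sample (the 42 below is an arbitrary choice):

kb.sampleByKey(False, fractions, seed=42).map(lambda x: x[1]).toDF(df.columns).show()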
Is this the kind of thing you are looking for?
James Tobin's solution above works for the example provided, but I had difficulty replicating the approach on my dataset (nearly 2 million records). Strange Java-related runtime errors occurred, and I could not pinpoint the problem (I was running PySpark in local mode).
An alternative approach is to reuse the single-column stratified-sampling method in a flexible way. To do so, we create a new (temporary) column that merges the values of the multiple columns we originally wanted to stratify on. We then perform the split and drop the merged column from the resulting splits.
import pyspark.sql.functions as F

def get_stratified_split_multiple_columns(input_df, col_name1, col_name2, seed_value=42, train_frac=0.6):
    '''
    Follows the approach of stratified sampling based on a single column presented at
    https://stackoverflow.com/a/47672336/530399 .
    Here, however, the single column is a merger of the values present in the
    multiple columns (`col_name1` and `col_name2`).
    Note that the PySpark split is not exact. If there are too few examples per
    category, it can happen that none of them end up in the validation/test
    split, which results in an error.
    '''
    # Create a temporary column that concatenates the two strata columns;
    # the "_#_@_#_" string acts as a separator between the values.
    merged_col_name = "both_labels"
    input_df = input_df.withColumn(merged_col_name,
                                   F.concat(F.col(col_name1), F.lit('_#_@_#_'), F.col(col_name2)))
    # One sampling fraction per distinct merged value.
    fractions1 = input_df.select(merged_col_name).distinct().withColumn(
        "fraction", F.lit(train_frac)).rdd.collectAsMap()
    train_df = input_df.stat.sampleBy(merged_col_name, fractions1, seed_value)
    valid_and_test_df = input_df.exceptAll(train_df)
    # 0.5 for an equal split of the remainder into validation and test sets.
    fractions2 = {key: 0.5 for key in fractions1}
    valid_df = valid_and_test_df.stat.sampleBy(merged_col_name, fractions2, seed_value)
    test_df = valid_and_test_df.exceptAll(valid_df)
    # Drop the merged column from all splits.
    train_df = train_df.drop(merged_col_name)
    valid_df = valid_df.drop(merged_col_name)
    test_df = test_df.drop(merged_col_name)
    return train_df, valid_df, test_df
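A quick usage sketch; labels_df and its column names are made up for this example, and the rows are repeated so that every stratum has enough examples to survive all three splits:

labels_df = spark.createDataFrame(
    [("a", "x", 1), ("a", "y", 2), ("b", "x", 3), ("b", "y", 4)] * 25,
    ["label1", "label2", "value"])
train_df, valid_df, test_df = get_stratified_split_multiple_columns(
    labels_df, "label1", "label2", seed_value=7, train_frac=0.6)
print(train_df.count(), valid_df.count(), test_df.count())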