
PySpark sampleBy using multiple columns

I want to carry out stratified sampling on a PySpark data frame. There is a sampleBy(col, fractions, seed=None) function, but it seems to accept only one column as the strata. Is there any way to use multiple columns as the strata?

Based on the answer here, after converting it to Python, I think an answer might look like:

#create a dataframe to use
df = sc.parallelize([ (1,1234,282),(1,1396,179),(2,8620,178),(3,1620,191),(3,8820,828) ] ).toDF(["ID","X","Y"])

#we are going to use the first two columns as our key (strata)
#assign sampling percentages to each key # you could do something cooler here
fractions = df.rdd.map(lambda x: (x[0],x[1])).distinct().map(lambda x: (x,0.3)).collectAsMap()

#setup how we want to key the dataframe
kb = df.rdd.keyBy(lambda x: (x[0],x[1]))

#create a dataframe after sampling from our newly keyed rdd
#note, if the sample did not return any values you'll get a `ValueError: RDD is empty` error

sampleddf = kb.sampleByKey(False,fractions).map(lambda x: x[1]).toDF(df.columns)
sampleddf.show()
+---+----+---+
| ID|   X|  Y|
+---+----+---+
|  1|1234|282|
|  1|1396|179|
|  3|1620|191|
+---+----+---+
#other examples: no seed is passed, so each run can return a different sample
kb.sampleByKey(False,fractions).map(lambda x: x[1]).toDF(df.columns).show()
+---+----+---+
| ID|   X|  Y|
+---+----+---+
|  2|8620|178|
+---+----+---+


kb.sampleByKey(False,fractions).map(lambda x: x[1]).toDF(df.columns).show()
+---+----+---+
| ID|   X|  Y|
+---+----+---+
|  1|1234|282|
|  1|1396|179|
+---+----+---+

Is this the kind of thing you were looking for?
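To make the varying outputs above easier to reason about, here is a minimal plain-Python sketch of what sampling without replacement by key amounts to: each row is kept independently with the probability assigned to its key. It does not use Spark, and the helper name `sample_by_key` and the `seed` value are assumptions made for illustration only. Treating a missing key as fraction 0 mirrors the documented behaviour of DataFrame.sampleBy and is also an assumption of this sketch.

```python
import random

# Same rows as the example data frame: (ID, X, Y)
rows = [(1, 1234, 282), (1, 1396, 179), (2, 8620, 178), (3, 1620, 191), (3, 8820, 828)]


def sample_by_key(rows, key_fn, fractions, seed=None):
    """Keep each row independently with probability fractions[key].

    Hypothetical helper: keys missing from `fractions` are treated as 0.
    """
    rng = random.Random(seed)
    return [row for row in rows if rng.random() < fractions.get(key_fn(row), 0.0)]


def key_fn(row):
    # Composite stratum key built from the first two columns
    return (row[0], row[1])


# Assign the same sampling fraction to every distinct (ID, X) key
fractions = {key_fn(row): 0.3 for row in rows}

sample = sample_by_key(rows, key_fn, fractions, seed=42)
print(sample)
```

Because every row is an independent coin flip, the sample size is not fixed, and with small strata an empty result is quite possible. Passing a fixed seed makes a run repeatable.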

James Tobin's solution above works fine with the example presented, but I had difficulties replicating the approach on my dataset (nearly 2 million records). A strange Java-related runtime error occurred and I was not able to pinpoint what the issue was (I was running PySpark in local mode).

An alternative approach is to adapt stratified sampling on a single column. For this, we create a new (temporary) column that merges the values of the multiple columns on which we originally wanted to stratify. Then we perform the splits and delete the merged column from the resulting splits.

from pyspark.sql import functions as F

random_seed_value = 42  # assumed value; the original snippet did not define it, any fixed integer works


def get_stratified_split_multiple_columns(input_df, col_name1, col_name2, seed_value=random_seed_value, train_frac=0.6):
    '''
    Following the approach of stratified sampling based on a single column as presented at
    https://stackoverflow.com/a/47672336/530399 .
    However, this time our single column is going to be a merger
    of the values present in the multiple columns (`col_name1` and `col_name2`).

    Note that PySpark sampling is not exact. Therefore, if there are too few examples per category,
    it could be that none of them go to the validation/test split, which may result in an error.
    '''

    merged_col_name = "both_labels"
    input_df = input_df.withColumn(merged_col_name, F.concat(F.col(col_name1), F.lit('_#_@_#_'),
                                                             F.col(col_name2)))  # The "_#_@_#_" acts as a separator between the values.
    
    fractions1 = input_df.select(merged_col_name).distinct().withColumn("fraction",
                                                                        F.lit(train_frac)).rdd.collectAsMap()
    train_df = input_df.stat.sampleBy(merged_col_name, fractions1, seed_value)

    valid_and_test_df = input_df.exceptAll(train_df)
    fractions2 = {key: 0.5 for key, value in fractions1.items()}  # 0.5 for equal split of valid and test set

    valid_df = valid_and_test_df.stat.sampleBy(merged_col_name, fractions2, seed_value)
    test_df = valid_and_test_df.exceptAll(valid_df)

    # Delete the merged_col_name from all splits
    train_df = train_df.drop(merged_col_name)
    valid_df = valid_df.drop(merged_col_name)
    test_df = test_df.drop(merged_col_name)

    return train_df, valid_df, test_df
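The separator used when merging the two columns matters. As a small plain-Python illustration (no Spark involved; `merge_key` is a hypothetical helper written for this sketch), two different value pairs can collapse into the same merged key when they are concatenated without a separator, which would silently merge two strata:

```python
SEP = "_#_@_#_"  # same separator string as in the function above


def merge_key(a, b, sep=SEP):
    # Hypothetical helper: builds the merged stratum label from two values
    return f"{a}{sep}{b}"


pairs = [("ab", "c"), ("a", "bc")]

# Without a separator both pairs become "abc", i.e. a single stratum
without_sep = {merge_key(a, b, sep="") for a, b in pairs}

# With the separator the two pairs stay distinct
with_sep = {merge_key(a, b) for a, b in pairs}

print(len(without_sep), len(with_sep))  # prints: 1 2
```

The separator only helps as long as it cannot appear inside the column values themselves, so an unusual string such as the one above is a reasonable choice.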


 