
Stratified sampling with Spark and Java

I'd like to make sure I'm training on a stratified sample of my data.

It seems this is supported in Spark 2.1 and earlier versions via JavaPairRDD.sampleByKey(...) and JavaPairRDD.sampleByKeyExact(...), as explained here.
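To make the difference between the two methods concrete, here is a plain-Java sketch (no Spark involved) of the approximate, one-pass idea behind sampleByKey. Every record is kept independently with the probability assigned to its label, so each stratum's sample size is only approximately count × fraction. The class and method names are hypothetical, rows are modelled as double[] with the label at index 0 (mirroring a Row whose first column is the label), and this is an illustration of the idea, not Spark's actual implementation.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Hypothetical illustration (not Spark code): approximate stratified sampling.
final class BernoulliStratifiedSketch {

    /**
     * Keeps each row independently with the probability configured for its label.
     * A row is a double[] whose element 0 is the label (assumption for this sketch).
     * Labels missing from fractions get probability 0, so their rows are dropped.
     * The expected size of each stratum is count * fraction; the actual size varies.
     */
    static List<double[]> sample(List<double[]> rows, Map<Double, Double> fractions, long seed) {
        Random rng = new Random(seed);
        List<double[]> out = new ArrayList<>();
        for (double[] row : rows) {
            double p = fractions.getOrDefault(row[0], 0.0);
            // nextDouble() is in [0, 1), so p = 1.0 always keeps and p = 0.0 never keeps.
            if (rng.nextDouble() < p) {
                out.add(row);
            }
        }
        return out;
    }

    public static void main(String[] args) {
        List<double[]> rows = new ArrayList<>();
        for (int i = 0; i < 1000; i++) {
            rows.add(new double[] {0.0, i});
            rows.add(new double[] {1.0, i});
        }
        Map<Double, Double> fractions = new HashMap<>();
        fractions.put(0.0, 0.8);
        fractions.put(1.0, 0.8);
        List<double[]> sampled = sample(rows, fractions, 2L);
        // Close to, but usually not exactly, 1600 rows.
        System.out.println("sampled " + sampled.size() + " of " + rows.size());
    }
}
```

Changing the seed changes how many rows each label contributes, which is exactly the variability that sampleByKeyExact is meant to remove.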

But my data is stored in a Dataset<Row>, not a JavaPairRDD. The first column is the label; all the others are features (imported from a libsvm-formatted file).

What's the easiest way to get a stratified sample of my dataset instance and at the end have a Dataset<Row> again?

In a way, this question is related to Dealing with unbalanced datasets in Spark MLlib.

This possible duplicate does not mention Dataset<Row> at all, nor is it in Java. It does not answer my question.

OK, since the answer to the question here was not written for Java, I have rewritten it in Java.

The reasoning is still the same, though: we still use sampleByKeyExact. There is no out-of-the-box feature for exact stratified sampling of a Dataset<Row> yet (as of Spark 2.1.0).
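According to its API documentation, sampleByKeyExact makes additional passes over the data so that the sample size equals the sum of ceil(count × fraction) over all keys (with 99.99% confidence). The plain-Java sketch below, which does not use Spark, shows what that target means for a small in-memory list: group rows by label, shuffle each group with a seeded Random, and keep exactly ceil(count × fraction) rows per label, without replacement. The names are hypothetical, rows are assumed to be double[] with the label at index 0, and the sketch illustrates the result rather than how Spark computes it on a cluster.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

// Hypothetical illustration (not Spark code): exact stratified sampling without replacement.
final class ExactStratifiedSketch {

    /**
     * Returns exactly ceil(count * fraction) rows per label (capped at the stratum size).
     * A row is a double[] whose element 0 is the label (assumption for this sketch).
     * Labels missing from fractions get fraction 0, so their rows are dropped.
     */
    static List<double[]> sample(List<double[]> rows, Map<Double, Double> fractions, long seed) {
        // Group rows into strata by label, preserving first-seen label order.
        Map<Double, List<double[]>> strata = new LinkedHashMap<>();
        for (double[] row : rows) {
            strata.computeIfAbsent(row[0], k -> new ArrayList<>()).add(row);
        }
        Random rng = new Random(seed);
        List<double[]> out = new ArrayList<>();
        for (Map.Entry<Double, List<double[]>> e : strata.entrySet()) {
            double fraction = fractions.getOrDefault(e.getKey(), 0.0);
            List<double[]> stratum = new ArrayList<>(e.getValue());
            // Target size for this stratum: ceil(n * fraction), never more than n.
            int k = (int) Math.min(stratum.size(), Math.ceil(stratum.size() * fraction));
            // A seeded shuffle followed by taking a prefix is sampling without replacement.
            Collections.shuffle(stratum, rng);
            out.addAll(stratum.subList(0, k));
        }
        return out;
    }

    public static void main(String[] args) {
        List<double[]> rows = new ArrayList<>();
        for (int i = 0; i < 7; i++) rows.add(new double[] {0.0, i});
        for (int i = 0; i < 13; i++) rows.add(new double[] {1.0, i});
        Map<Double, Double> fractions = new HashMap<>();
        fractions.put(0.0, 0.8);
        fractions.put(1.0, 0.8);
        System.out.println("sampled " + sample(rows, fractions, 2L).size() + " of " + rows.size());
    }
}
```

With 7 rows labelled 0.0, 13 labelled 1.0 and a fraction of 0.8, this keeps exactly 6 and 11 rows respectively (ceil(5.6) and ceil(10.4)), whatever the seed.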

So here you go:

package org.awesomespark.examples;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import scala.Tuple2;

import java.util.Map;

public class StratifiedDatasets {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("Stratified Datasets")
                .getOrCreate();

        // libsvm source: column 0 is the label, column 1 the feature vector.
        Dataset<Row> data = spark.read().format("libsvm").load("sample_libsvm_data.txt");

        // Key every row by its label so that each label becomes one stratum.
        JavaPairRDD<Double, Row> rdd = data.toJavaRDD().keyBy(row -> row.getDouble(0));

        // Sampling fraction per label: keep 80% of every class.
        Map<Double, Double> fractions = rdd.keys()
                .distinct()
                .mapToPair(label -> new Tuple2<>(label, 0.8))
                .collectAsMap();

        // Exact stratified sample without replacement (seed 2), then drop the keys.
        JavaRDD<Row> sampledRDD = rdd.sampleByKeyExact(false, fractions, 2L).values();

        // Rebuild a Dataset<Row> with the original schema.
        Dataset<Row> sampledData = spark.createDataFrame(sampledRDD, data.schema());

        sampledData.show();
        sampledData.printSchema();

        spark.stop();
    }
}
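The code above gives every label the same fraction (0.8), which keeps the original class ratio. Because sampleByKeyExact takes a per-key map, the fractions can also differ by label, which is one way to approach the unbalanced-dataset case mentioned in the question. The hypothetical helper below (plain Java, an assumption rather than part of this answer) turns per-label counts into fractions that would downsample every class to the size of the smallest one. In the Spark program, such counts could come from rdd.countByKey(), which returns a java.util.Map<Double, Long> for this pair RDD.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Hypothetical helper (not part of Spark): per-label fractions for class balancing.
final class BalancedFractionsSketch {

    /**
     * Given per-label row counts, returns fraction = minCount / count for each label,
     * so that sampling each label at its fraction yields roughly minCount rows per label.
     * Throws NoSuchElementException if counts is empty.
     */
    static Map<Double, Double> downsampleToSmallest(Map<Double, Long> counts) {
        long min = Collections.min(counts.values());
        Map<Double, Double> fractions = new HashMap<>();
        for (Map.Entry<Double, Long> e : counts.entrySet()) {
            fractions.put(e.getKey(), (double) min / e.getValue());
        }
        return fractions;
    }

    public static void main(String[] args) {
        Map<Double, Long> counts = new HashMap<>();
        counts.put(0.0, 100L);
        counts.put(1.0, 400L);
        System.out.println(downsampleToSmallest(counts));
    }
}
```

For counts of 100 and 400, the smaller class gets a fraction of 1.0 and the larger one 0.25; the resulting map could then be passed to sampleByKeyExact in place of the uniform 0.8 map.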

Now let's package and submit:

$ sbt package
[...]
// [success] Total time: 2 s, completed Jan 16, 2017 1:45:51 PM

$ spark-submit --class org.awesomespark.examples.StratifiedDatasets target/scala-2.10/java-stratified-dataset_2.10-1.0.jar 
[...]

// +-----+--------------------+
// |label|            features|
// +-----+--------------------+
// |  0.0|(692,[127,128,129...|
// |  1.0|(692,[158,159,160...|
// |  1.0|(692,[124,125,126...|
// |  1.0|(692,[152,153,154...|
// |  1.0|(692,[151,152,153...|
// |  0.0|(692,[129,130,131...|
// |  1.0|(692,[99,100,101,...|
// |  0.0|(692,[154,155,156...|
// |  0.0|(692,[127,128,129...|
// |  1.0|(692,[154,155,156...|
// |  0.0|(692,[151,152,153...|
// |  1.0|(692,[129,130,131...|
// |  0.0|(692,[154,155,156...|
// |  1.0|(692,[150,151,152...|
// |  0.0|(692,[124,125,126...|
// |  0.0|(692,[152,153,154...|
// |  1.0|(692,[97,98,99,12...|
// |  1.0|(692,[124,125,126...|
// |  1.0|(692,[156,157,158...|
// |  1.0|(692,[127,128,129...|
// +-----+--------------------+
// only showing top 20 rows

// root
//  |-- label: double (nullable = true)
//  |-- features: vector (nullable = true)

For Python users, you can also check my answer to Stratified sampling with pyspark.
