簡體   English   中英

使用Spark和Java進行分層抽樣

[英]Stratified sampling with Spark and Java

我想確保我正在對我的數據進行分層抽樣培訓。

看來,這是由星火2.1和更早版本支持通過JavaPairRDD.sampleByKey(...)JavaPairRDD.sampleByKeyExact(...)作為解釋在這里

但是:我的數據存儲在Dataset<Row> ,而不是JavaPairRDD 第一列是標簽,所有其他都是功能(從libsvm格式的文件導入)。

獲取數據集實例的分層樣本的最簡單方法是什么?最后再次有Dataset<Row>

在某種程度上,這個問題與在Spark MLlib中處理不平衡數據集有關

這個可能的副本根本沒有提到Dataset<Row> ,也沒有提到Java。 它沒有回答我的問題。

好的,既然這里的問題的答案實際上不是針對Java的 ,那么我已經用Java重寫了它。

推理仍然是同樣的想法。 我們仍在使用sampleByKeyExact 現在沒有開箱即用的奇跡功能( 火花2.1.0

所以你走了:

package org.awesomespark.examples;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.*;
import scala.Tuple2;

import java.util.Map;

public class StratifiedDatasets {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
                .appName("Stratified Datasets")
                .getOrCreate();

        Dataset<Row> data = spark.read().format("libsvm").load("sample_libsvm_data.txt");

        JavaPairRDD<Double, Row> rdd = data.toJavaRDD().keyBy(x -> x.getDouble(0));
        Map<Double, Double> fractions = rdd.map(Tuple2::_1)
                .distinct()
                .mapToPair((PairFunction<Double, Double, Double>) (Double x) -> new Tuple2(x, 0.8))
                .collectAsMap();

        JavaRDD<Row> sampledRDD = rdd.sampleByKeyExact(false, fractions, 2L).values();
        Dataset<Row> sampledData = spark.createDataFrame(sampledRDD, data.schema());

        sampledData.show();
        sampledData.printSchema();
    }
}

現在讓我們打包並提交:

$ sbt package
[...]
// [success] Total time: 2 s, completed Jan 16, 2017 1:45:51 PM

$ spark-submit --class org.awesomespark.examples.StratifiedDatasets target/scala-2.10/java-stratified-dataset_2.10-1.0.jar 
[...]

// +-----+--------------------+
// |label|            features|
// +-----+--------------------+
// |  0.0|(692,[127,128,129...|
// |  1.0|(692,[158,159,160...|
// |  1.0|(692,[124,125,126...|
// |  1.0|(692,[152,153,154...|
// |  1.0|(692,[151,152,153...|
// |  0.0|(692,[129,130,131...|
// |  1.0|(692,[99,100,101,...|
// |  0.0|(692,[154,155,156...|
// |  0.0|(692,[127,128,129...|
// |  1.0|(692,[154,155,156...|
// |  0.0|(692,[151,152,153...|
// |  1.0|(692,[129,130,131...|
// |  0.0|(692,[154,155,156...|
// |  1.0|(692,[150,151,152...|
// |  0.0|(692,[124,125,126...|
// |  0.0|(692,[152,153,154...|
// |  1.0|(692,[97,98,99,12...|
// |  1.0|(692,[124,125,126...|
// |  1.0|(692,[156,157,158...|
// |  1.0|(692,[127,128,129...|
// +-----+--------------------+
// only showing top 20 rows

// root
//  |-- label: double (nullable = true)
//  |-- features: vector (nullable = true)

對於python用戶,您還可以檢查我的答案使用pyspark進行分層抽樣

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM