
Spark MLlib classification input format using Java

How do I convert a list of DTOs into the input Dataset format that Spark ML expects?

I have this DTO:

public class MachineLearningDTO implements Serializable {
    private double label;
    private double[] features;

    public MachineLearningDTO() {
    }

    public MachineLearningDTO(double label, double[] features) {
        this.label = label;
        this.features = features;
    }

    public double getLabel() {
        return label;
    }

    public void setLabel(double label) {
        this.label = label;
    }

    public double[] getFeatures() {
        return features;
    }

    public void setFeatures(double[] features) {
        this.features = features;
    }
}

And this code:

Dataset<MachineLearningDTO> mlInputDataSet = spark.createDataset(mlInputData, Encoders.bean(MachineLearningDTO.class));
LogisticRegression logisticRegression = new LogisticRegression();
LogisticRegressionModel model = logisticRegression.fit(MLUtils.convertMatrixColumnsToML(mlInputDataSet));

When I run the code, I get:

java.lang.IllegalArgumentException: requirement failed: Column features must be of type org.apache.spark.ml.linalg.VectorUDT@3bfc3ba7 but was actually ArrayType(DoubleType,false).

If I try to change it to org.apache.spark.ml.linalg.VectorUDT with this code:

VectorUDT vectorUDT = new VectorUDT();
vectorUDT.serialize(Vectors.dense(......));

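Then I get:

java.lang.UnsupportedOperationException: Cannot infer type for class org.apache.spark.ml.linalg.VectorUDT because it is not bean-compliant

    at org.apache.spark.sql.catalyst.JavaTypeInference$.org$apache$spark$sql$catalyst$JavaTypeInference$$serializerFor(JavaTypeInference.scala:437)

I have figured it out. In case anyone else gets stuck on this: I wrote a simple converter that builds the rows by hand with an explicit schema instead of relying on the bean encoder, and it works: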

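// Vectors and VectorUDT must come from org.apache.spark.ml.linalg (not org.apache.spark.mllib.linalg).
private Dataset<Row> convertToMlInputFormat(List<MachineLearningDTO> data) {
    // Build one Row per DTO; the double[] features become an ML dense vector.
    // getLabel() already returns a double, so it is passed through unchanged.
    List<Row> rowData = data.stream()
            .map(dto -> RowFactory.create(dto.getLabel(), Vectors.dense(dto.getFeatures())))
            .collect(Collectors.toList());
    // Explicit schema: "features" is declared as VectorUDT, the type LogisticRegression requires.
    StructType schema = new StructType(new StructField[]{
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty()),
    });

    return spark.createDataFrame(rowData, schema);
}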
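
For anyone who wants to try the converter end to end, here is a minimal sketch under a few assumptions: a local SparkSession, four made-up training points, and a compact stand-in for the DTO from the question. The class name ConverterUsageSketch and the sample values are hypothetical. The sketch passes the SparkSession as a parameter and declares the features column with SQLDataTypes.VectorType(), the public accessor for the same vector type, rather than new VectorUDT(); either should produce a column that LogisticRegression accepts.

```java
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.linalg.SQLDataTypes;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class ConverterUsageSketch {

    // Compact stand-in for the MachineLearningDTO from the question (hypothetical copy).
    public static class MachineLearningDTO implements Serializable {
        private final double label;
        private final double[] features;

        public MachineLearningDTO(double label, double[] features) {
            this.label = label;
            this.features = features;
        }

        public double getLabel() { return label; }
        public double[] getFeatures() { return features; }
    }

    // Same approach as the converter above: hand-built Rows plus an explicit schema.
    static Dataset<Row> convertToMlInputFormat(SparkSession spark, List<MachineLearningDTO> data) {
        List<Row> rows = data.stream()
                .map(dto -> RowFactory.create(dto.getLabel(), Vectors.dense(dto.getFeatures())))
                .collect(Collectors.toList());
        StructType schema = new StructType(new StructField[]{
                new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
                // Assumption: SQLDataTypes.VectorType() is used here in place of new VectorUDT().
                new StructField("features", SQLDataTypes.VectorType(), false, Metadata.empty())
        });
        return spark.createDataFrame(rows, schema);
    }

    public static void main(String[] args) {
        // Assumption: a local session is enough for a quick check.
        SparkSession spark = SparkSession.builder()
                .master("local[*]")
                .appName("converter-usage-sketch")
                .getOrCreate();

        // Made-up sample data: two points per class.
        List<MachineLearningDTO> data = Arrays.asList(
                new MachineLearningDTO(0.0, new double[]{0.0, 1.1}),
                new MachineLearningDTO(0.0, new double[]{0.2, 0.9}),
                new MachineLearningDTO(1.0, new double[]{2.0, 0.1}),
                new MachineLearningDTO(1.0, new double[]{1.8, 0.3}));

        Dataset<Row> input = convertToMlInputFormat(spark, data);
        input.printSchema(); // "features" should now be listed as a vector column

        LogisticRegressionModel model = new LogisticRegression()
                .setMaxIter(10)
                .setRegParam(0.01)
                .fit(input);

        // Show the model's predictions next to the original labels.
        model.transform(input).select("label", "features", "prediction").show();

        spark.stop();
    }
}
```

The point of the design is that Encoders.bean never sees the vector type: the schema is declared explicitly, so no bean type inference takes place for the features column.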

