繁体   English   中英

转换JavaRDD <Row> 到JavaRDD <Vector>

[英]Convert JavaRDD<Row> to JavaRDD<Vector>

我正在尝试在Wikipedia XML转储上执行LDA。 在获取原始文本的RDD之后,我将创建一个数据框并通过Tokenizer,StopWords和CountVectorizer管道对其进行转换。 我打算将来自CountVectorizer的Vectors输出的RDD传递给MLLib中的OnlineLDA。 这是我的代码:

 // Configure an ML pipeline
 RegexTokenizer tokenizer = new RegexTokenizer()
   .setInputCol("text")
   .setOutputCol("words");

 StopWordsRemover remover = new StopWordsRemover()
          .setInputCol("words")
          .setOutputCol("filtered");

 CountVectorizer cv = new CountVectorizer()
          .setVocabSize(vocabSize)
          .setInputCol("filtered")
          .setOutputCol("features");

 Pipeline pipeline = new Pipeline()
          .setStages(new PipelineStage[] {tokenizer, remover, cv});

// Fit the pipeline to train documents.
 PipelineModel model = pipeline.fit(fileDF);

 JavaRDD<Vector> countVectors = model.transform(fileDF)
          .select("features").toJavaRDD()
          .map(new Function<Row, Vector>() {
            public Vector call(Row row) throws Exception {
                Object[] arr = row.getList(0).toArray();

                double[] features = new double[arr.length];
                int i = 0;
                for(Object obj : arr){
                    features[i++] = (double)obj;
                }
                return Vectors.dense(features);
            }
          });

我因为行而得到类强制转换异常

Object[] arr = row.getList(0).toArray();


Caused by: java.lang.ClassCastException: org.apache.spark.mllib.linalg.SparseVector cannot be cast to scala.collection.Seq
at org.apache.spark.sql.Row$class.getSeq(Row.scala:278)
at org.apache.spark.sql.catalyst.expressions.GenericRow.getSeq(rows.scala:192)
at org.apache.spark.sql.Row$class.getList(Row.scala:286)
at org.apache.spark.sql.catalyst.expressions.GenericRow.getList(rows.scala:192)
at xmlProcess.ParseXML$2.call(ParseXML.java:142)
at xmlProcess.ParseXML$2.call(ParseXML.java:1)

我在这里找到了Scala语法来执行此操作但是找不到在Java中执行此操作的任何示例。 我尝试了row.getAs[Vector](0)但这只是Scala语法。 用Java可以做到吗?

因此,我能够通过简单地转换为Vector来做到这一点。 我不知道为什么我没有先尝试简单的事情!

         JavaRDD<Vector> countVectors = model.transform(fileDF)
              .select("features").toJavaRDD()
              .map(new Function<Row, Vector>() {
                public Vector call(Row row) throws Exception {
                    return (Vector)row.get(0);
                }
              });

您无需将DataFrame/DataSet转换为JavaRDD即可与LDA 经过几个小时的摆弄,我终于让Scala的本地rdd工作了。

相关进口:

import org.apache.spark.ml.feature.{CountVectorizer, RegexTokenizer, StopWordsRemover}
import org.apache.spark.ml.linalg.{Vector => MLVector}
import org.apache.spark.mllib.clustering.{LDA, OnlineLDAOptimizer}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.{Row, SparkSession}

其余代码与下面的示例相同:

val cvModel = new CountVectorizer()
        .setInputCol("filtered")
        .setOutputCol("features")
        .setVocabSize(vocabSize)
        .fit(filteredTokens)


val countVectors = cvModel
        .transform(filteredTokens)
        .select("docId","features")
        .rdd.map { case Row(docId: String, features: MLVector) => 
                   (docId.toLong, Vectors.fromML(features)) 
                 }
val mbf = {
    // add (1.0 / actualCorpusSize) to MiniBatchFraction be more robust on tiny datasets.
    val corpusSize = countVectors.count()
    2.0 / maxIterations + 1.0 / corpusSize
  }
  val lda = new LDA()
    .setOptimizer(new OnlineLDAOptimizer().setMiniBatchFraction(math.min(1.0, mbf)))
    .setK(numTopics)
    .setMaxIterations(2)
    .setDocConcentration(-1) // use default symmetric document-topic prior
    .setTopicConcentration(-1) // use default symmetric topic-word prior

  val startTime = System.nanoTime()
  val ldaModel = lda.run(countVectors)
  val elapsed = (System.nanoTime() - startTime) / 1e9

  /**
    * Print results.
    */
  // Print training time
  println(s"Finished training LDA model.  Summary:")
  println(s"Training time (sec)\t$elapsed")
  println(s"==========")

由于去的代码的作者在这里

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM