
Spark DataFrame CountVectorizerModel Error With DataType String

I have the following code, which attempts a simple operation: converting a sparse vector into a dense vector. This is what I have so far:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoder}
import org.apache.spark.ml.feature.CountVectorizerModel
import org.apache.spark.mllib.linalg.Vector
import spark.implicits._

// Identify how many distinct values are in the OCEAN_PROXIMITY column
val distinctOceanProximities = dfRaw.select(col("ocean_proximity")).distinct().as[String].collect()

val cvmDF = new CountVectorizerModel(tags)
  .setInputCol("ocean_proximity")
  .setOutputCol("sparseFeatures")
  .transform(dfRaw)
  
val exprs = (0 until distinctOceanProximities.size).map(i => $"features".getItem(i).alias(s"$distinctOceanProximities(i)"))
val vecToSeq = udf((v: Vector) => v.toArray)

val df2 = cvmDF.withColumn("features", vecToSeq($"sparseFeatures")).select(exprs:_*)
df2.show()

When I run this script, I get the following error:

java.lang.IllegalArgumentException: requirement failed: Column ocean_proximity must be of type equal to one of the following types: [array<string>, array<string>] but was actually of type string.
  at scala.Predef$.require(Predef.scala:281)
  at org.apache.spark.ml.util.SchemaUtils$.checkColumnTypes(SchemaUtils.scala:63)
  at org.apache.spark.ml.feature.CountVectorizerParams.validateAndTransformSchema(CountVectorizer.scala:97)
  at org.apache.spark.ml.feature.CountVectorizerParams.validateAndTransformSchema$(CountVectorizer.scala:95)
  at org.apache.spark.ml.feature.CountVectorizerModel.validateAndTransformSchema(CountVectorizer.scala:272)
  at org.apache.spark.ml.feature.CountVectorizerModel.transformSchema(CountVectorizer.scala:338)
  at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:71)
  at org.apache.spark.ml.feature.CountVectorizerModel.transform(CountVectorizer.scala:306)
  ... 101 elided

I think it expects a sequence of strings as the column's data type, but I only have a single string. Any ideas on how to fix this?

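This turned out to be simple. All I had to do was convert the column from a string to an array of strings (using `array` from `org.apache.spark.sql.functions`) and pass the resulting DataFrame to `transform` instead of `dfRaw`: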

val oceanProximityAsArrayDF = dfRaw.withColumn("ocean_proximity", array("ocean_proximity"))
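
For reference, here is a minimal end-to-end sketch of how the fix might fit into the code from the question. It is not the original poster's full solution. The sample rows standing in for `dfRaw` are made up. Using `distinctOceanProximities` as the vocabulary in place of the undefined `tags` is an assumption, as is sorting it to get a stable column order. The sketch also imports `Vector` from `org.apache.spark.ml.linalg` (the type the `ml` transformers produce) and uses `${...}`-style indexing for the alias, since `s"$distinctOceanProximities(i)"` would interpolate the whole array rather than element `i`. It is written for `spark-shell` or a Scala script.

```scala
import org.apache.spark.ml.feature.CountVectorizerModel
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, col, udf}

// In spark-shell this returns the existing session.
val spark = SparkSession.builder().master("local[*]").appName("cvm-sketch").getOrCreate()
import spark.implicits._

// Hypothetical stand-in for dfRaw: one string column, as in the question.
val dfRaw = Seq("NEAR BAY", "INLAND", "NEAR BAY", "ISLAND").toDF("ocean_proximity")

// Distinct category values; sorted only to make the column order deterministic.
val distinctOceanProximities =
  dfRaw.select(col("ocean_proximity")).distinct().as[String].collect().sorted

// The fix: wrap each string in a one-element array so the column is array<string>.
val oceanProximityAsArrayDF =
  dfRaw.withColumn("ocean_proximity", array(col("ocean_proximity")))

// Assumption: the distinct values serve as the vocabulary (the question's `tags`).
val cvmDF = new CountVectorizerModel(distinctOceanProximities)
  .setInputCol("ocean_proximity")
  .setOutputCol("sparseFeatures")
  .transform(oceanProximityAsArrayDF)

// Convert the ml Vector into a plain array<double> column.
val vecToSeq = udf((v: Vector) => v.toArray)

// One output column per vocabulary entry, named after that entry.
val exprs = distinctOceanProximities.indices.map { i =>
  col("features").getItem(i).alias(distinctOceanProximities(i))
}

val df2 = cvmDF
  .withColumn("features", vecToSeq(col("sparseFeatures")))
  .select(exprs: _*)
df2.show()
```

Each row of `df2` should then hold one count per distinct `ocean_proximity` value, with the columns in vocabulary order.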
