[英]spark dataframe format error when using saved model to predict on new data
我能夠訓練模型並保存模型(Train.scala)。 現在,我想使用訓練有素的模型對新數據進行預測(Predict.scala)。
我創建了一個新的VectorAssembler
在我Predict.scala
到特征化的新數據。 我應該使用相同的VectorAssembler
在Train.scala
為Predict.scala
文件? 因為我看到轉換后的要素數據類型有問題。
例如:當我讀取訓練有素的模型並嘗試預測已完成的新數據時,出現此錯誤:
type mismatch;
[error] found : org.apache.spark.sql.DataFrame
[error] (which expands to) org.apache.spark.sql.Dataset[org.apache.spark.sql.Row]
[error] required: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] => org.apache.spark.sql.Dataset[?]
[error] val predictions = model.transform(featureData)
培訓代碼:Train.scala
// assembler
val assembler = new VectorAssembler()
.setInputCols(feature_list)
.setOutputCol("features")
//read in train data
val trainingData = spark
.read
.parquet(train_data_path)
// generate training features
val trainingFeatures = assembler.transform(trainingData)
//define model
val lightGBMClassifier = new LightGBMClassifier()
.setLabelCol("label")
.setFeaturesCol("features")
.setIsUnbalance(true)
.setMaxDepth(25)
.setNumLeaves(31)
.setNumIterations(100)
// fit model
val lgbm = lightGBMClassifier.fit(trainingFeatures)
//save model
lgbm
.write
.overwrite()
.save(my_model_s3_path)
預測代碼:Predict.scala
val assembler = new VectorAssembler()
.setInputCols(feature_list)
.setOutputCol("features")
// load model
val model = spark.read.parquet(my_model_s3_path)
// load new data
val inputData = spark.read.parquet(new_data_path)
//Assembler to transform new data
val featureData = assembler.transform(inputData)
//predict on new data
val predictions = model.transform(featureData) ### <- got error here
我是否應該使用其他方法讀取訓練有素的模型或轉換數據?
“我應該在Train.scala中為Predict.scala文件使用相同的VectorAssembler嗎?” 是的,但是,我強烈建議您使用Pipelines 。
// Train.scala
val pipeline = new Pipeline().setStages(Array(assembler, lightGBMClassifier))
val pipelineModel = pipeline.fit(trainingData)
pipelineModel.write.overwrite().save("/path/to/pipelineModel")
// Predict.scala
val pipelineModel = PipelineModel.load("/path/to/pipelineModel")
val predictions = pipelineModel.transform(inputData)
看看問題是否消失了,只需使用管道,正確地對模型進行序列化/反序列化以及更好地組織代碼即可。 另外,請確保trainingData和inputData都包含在feature_list中列出的相同列。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.