
How to develop a REST API using an ML model trained on Apache Spark?

Assume this scenario:

We analyze the data, train some machine learning models using whatever tools we have at hand, and save those models. This is done in Python, using the Apache Spark Python shell and API. We know Apache Spark is good at batch processing, which makes it a good choice for the above scenario.
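
For concreteness, the training step is roughly the sketch below: a minimal PySpark ML Pipeline example in which the columns, stages, and paths are all made up.

# Minimal sketch of the batch training/saving step with PySpark's ML Pipeline API.
# Column names, stages, and paths are illustrative, not our actual pipeline.
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml.classification import LogisticRegression

spark = SparkSession.builder.appName("batch-training").getOrCreate()
train_df = spark.read.parquet("/data/training_set.parquet")

pipeline = Pipeline(stages=[
    Tokenizer(inputCol="text", outputCol="words"),
    HashingTF(inputCol="words", outputCol="features"),
    LogisticRegression(featuresCol="features", labelCol="label"),
])

model = pipeline.fit(train_df)
model.save("/models/my_pipeline")  # writes metadata plus per-stage model data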

Now, going into production, for each given request we need to return a response that also depends on the output of the trained model. This is, I assume, what people call stream processing, and Apache Flink is usually recommended for it. But how would you use the same models, trained with tools available in Python, in a Flink pipeline?

The micro-batch mode of Spark wouldn't work here, since we really need to respond to each request, and not in batches.
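
To make the goal concrete, the serving side should have roughly the shape sketched below: one prediction per incoming request. Flask is just an example framework, and the stand-in model is a placeholder, since what should actually sit behind that predict() call is exactly the question.

# Sketch of the desired per-request serving shape, with a stand-in model.
# The open question is what replaces StandInModel when the real model was
# trained and saved with Spark.
from flask import Flask, request, jsonify

class StandInModel:
    def predict(self, record):
        return 0.0  # dummy value; the real model's prediction would go here

app = Flask(__name__)
model = StandInModel()

@app.route("/predict", methods=["POST"])
def predict():
    record = request.get_json()  # one record per request, not a micro-batch
    return jsonify({"prediction": model.predict(record)})

if __name__ == "__main__":
    app.run(port=8080)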

I've also seen some libraries that try to do machine learning in Flink, but that doesn't meet the needs of people whose tools are in Python rather than Scala, and who may not be familiar with Scala at all.

So the question is, how do people approach this problem?

This question is related, but not a duplicate, since the author there explicitly mentions using Spark MLlib. That library runs on the JVM and so has more potential to be ported to other JVM-based platforms. But here the question is how people would approach it if, let's say, they were using scikit-learn, GPy, or whatever other method/package they use.
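
On the Python side those models are easy enough to load back, e.g. a scikit-learn model via joblib (a hedged sketch below, with a made-up file name and feature vector); the hard part is calling such a model from a JVM-based serving or streaming stack.

# A scikit-learn model is an ordinary Python object, so loading and scoring it
# from Python is straightforward. File name and features are made up.
import joblib

model = joblib.load("model.joblib")            # saved earlier with joblib.dump(model, "model.joblib")
prediction = model.predict([[1.0, 2.0, 3.0]])  # score a single feature vector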

I needed a way of creating a custom Transformer for an ML Pipeline and having that custom object saved and loaded along with the rest of the pipeline. This led me to dig into the very ugly depths of Spark model serialisation/deserialisation.

In short, it looks like all the Spark ML models have two components, metadata and model data, where the model data is whatever parameters were learned during .fit(). The metadata is saved in a directory called metadata under the model save directory and, as far as I can tell, is JSON, so that shouldn't be an issue. The model parameters themselves seem to be saved simply as a Parquet file in the save directory. This is the implementation for saving an LDA model:

override protected def saveImpl(path: String): Unit = {
  // Write the JSON metadata under <path>/metadata
  DefaultParamsWriter.saveMetadata(instance, path, sc)
  // Bundle the learned parameters into a Data case class...
  val oldModel = instance.oldLocalModel
  val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
    oldModel.topicConcentration, oldModel.gammaShape)
  // ...and write them as a single Parquet file under <path>/data
  val dataPath = new Path(path, "data").toString
  sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
}

Notice the sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) on the last line. So the good news is that you could load that file into your web server, and if the server is in Java/Scala you'll just need to keep the Spark jars on the classpath.

If, however, you're using, say, Python for the web server, you could use a Parquet library for Python (https://github.com/jcrobak/parquet-python). The bad news is that some or all of the objects in the Parquet file are going to be binary Java dumps, so you can't actually read them in Python. A few options come to mind: use Jython (meh), or use Py4J to load the objects; this is what PySpark uses to communicate with the JVM, so it could actually work. I wouldn't expect it to be exactly straightforward, though.
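
To make the Py4J option a bit more concrete: since PySpark already talks to the JVM over Py4J, one hedged variant is simply to embed a local SparkSession in the Python web server and let Spark load the saved model itself. A sketch with an illustrative path and column name; note that pulling a JVM and a SparkSession into every server process is heavyweight if you care about per-request latency.

# Hedged sketch: reuse PySpark's own Py4J bridge rather than driving Py4J by hand.
# A local SparkSession loads the saved pipeline (metadata JSON + Parquet data)
# and scores one incoming record at a time. Path and column name are illustrative.
from pyspark.sql import SparkSession
from pyspark.ml import PipelineModel

spark = SparkSession.builder.master("local[1]").appName("model-serving").getOrCreate()
model = PipelineModel.load("/models/my_pipeline")

def predict_one(text):
    # Wrap the single record in a one-row DataFrame, run the pipeline, read the result.
    df = spark.createDataFrame([(text,)], ["text"])
    return model.transform(df).select("prediction").head()["prediction"]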

Or, as in the linked question, use jpmml-spark and hope for the best.

Have a look at MLeap.

We have had some success externalizing models learned on Spark into separate services that provide predictions on new incoming data. We externalized an LDA topic modelling pipeline, albeit in Scala. But MLeap does have Python support, so it's worth looking at.
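
For reference, here is a hedged sketch of what the MLeap export from PySpark looks like; the mleap.pyspark imports and the serializeToBundle call follow MLeap's documented PySpark support as I remember it, so treat the exact names as assumptions and check the MLeap README. Paths and the sample DataFrame are illustrative.

# Hedged sketch of exporting a fitted PySpark pipeline to an MLeap bundle, which
# an MLeap runtime can then score without a SparkContext. The mleap.pyspark
# imports and the serializeToBundle signature follow MLeap's docs as recalled
# here; verify against the current MLeap README. Paths are illustrative.
import mleap.pyspark  # noqa: F401
from mleap.pyspark.spark_support import SimpleSparkSerializer  # noqa: F401
from pyspark.sql import SparkSession
from pyspark.ml import PipelineModel

spark = SparkSession.builder.appName("mleap-export").getOrCreate()
model = PipelineModel.load("/models/my_pipeline")
sample_df = spark.read.parquet("/data/training_set.parquet")  # schema the pipeline expects
model.serializeToBundle("jar:file:/tmp/my_pipeline.zip", model.transform(sample_df))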
