
Using model.predict (Keras + TF) in multiprocessing

I have the following problem. I'm using a TensorFlow Keras model to evaluate continuous sensor data. The input for my model consists of 15 sensor data frames. Because the function model.predict() takes nearly 1 second, I wanted to execute it asynchronously so that I can collect the next data frames during that time. To accomplish this I created a Pool with the multiprocessing library and a function that wraps model.predict. My code looks something like this:

import tensorflow as tf
from multiprocessing import Pool

def predictData(data):
    return model.predict(data)

model = tf.keras.models.load_model("Network.h5")
model._make_predict_function()

p = Pool(processes=4)
...
res = p.apply_async(predictData, ([[iinput]],))
print(res.get(timeout=10))

Now I always get a timeout error when calling predictData(). It seems like model.predict() is not working correctly. What am I doing wrong?

It is possible to run multiple predictions in multiple concurrent Python processes, but each independent process has to build its own TensorFlow computational graph and then call keras.model.predict.

Write a function to use with the multiprocessing module (with the Process or Pool class). Within this function, build your model, TensorFlow graph, and whatever else you need, set all TensorFlow and Keras variables, then call the predict method, and pipe the result back to your master process.

For example:

    from multiprocessing import Pool

    def f(data):
        # Import and build everything inside the worker process so that
        # each process owns its own TensorFlow graph and session.
        import tensorflow
        import keras

        # Configure your TensorFlow and Keras settings here (e.g. GPU/CPU usage).

        keras_model = build_your_keras_model()
        result = keras_model.predict(data)
        return result

    if __name__ == '__main__':
        p = Pool(processes=4)
        res = p.apply_async(f, (data,))
        print(res.get(timeout=10))

The reason is that each process you spawn requires its own newly initialized copy of the model to make predictions with. Therefore you have to make sure you instantiate/load the model in every spawned process. This is definitely not optimal.
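One way to pay the loading cost only once per worker, rather than on every predict call, is the Pool's initializer hook. Here is a minimal sketch of that approach; it is an addition to this answer, and the Network.h5 file name and the (1, 15) placeholder input are assumptions carried over from the question:

    import numpy as np
    from multiprocessing import Pool

    model = None  # each worker process fills in its own copy

    def init_worker():
        # Runs once in every worker when the pool starts, so the model
        # is loaded once per process instead of once per predict call.
        global model
        import tensorflow as tf
        model = tf.keras.models.load_model("Network.h5")

    def predict_data(data):
        # Uses the worker-local model loaded by init_worker.
        return model.predict(data)

    if __name__ == '__main__':
        p = Pool(processes=4, initializer=init_worker)
        frame = np.zeros((1, 15))  # placeholder input; replace with real sensor frames
        res = p.apply_async(predict_data, (frame,))
        print(res.get(timeout=60))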

Per-process instantiation is a known caveat of multiprocessing for machine-learning training and/or inference. Some libraries come with multiprocessing features out of the box and provide parallelizable calls to their models. With most libraries, however, once you want to do multiprocessing you are on your own!

Make sure you instantiate your model once and then find a way to share that model across processes. One basic way to do that is to serve your model as a Flask service and then make predictions against that service to your heart's content, as sketched below. Cheers!
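As a rough illustration of that Flask idea, here is a minimal sketch; the /predict route, the JSON layout, and the Network.h5 file name are assumptions for the example, not a fixed API:

    import numpy as np
    import tensorflow as tf
    from flask import Flask, request, jsonify

    app = Flask(__name__)

    # Load the model once at start-up; every request reuses this instance.
    model = tf.keras.models.load_model("Network.h5")

    @app.route("/predict", methods=["POST"])
    def predict():
        # Expects JSON like {"data": [[...], ...]} with one sample per row.
        data = np.array(request.get_json()["data"])
        prediction = model.predict(data)
        return jsonify(prediction.tolist())

    if __name__ == "__main__":
        app.run(host="0.0.0.0", port=5000)

The sensor loop can then keep collecting frames and POST each batch to the service (e.g. requests.post("http://localhost:5000/predict", json={"data": frames})), which gives the asynchronous behaviour the question asked about.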
