
Using Keras for real-time training and predicting

I want to use Keras in a real-time training and prediction setting. In my scenario, I receive real-time data via MQTT that should be used to train an (LSTM) neural network and/or be fed to the network to get a prediction.

I am using the TensorFlow backend with GPU support and a fairly powerful GPU, but in my scenario Keras does not really profit from GPU acceleration. (I ran some performance tests with the examples in the Keras repository to make sure that GPU acceleration works in general.) In my first approach, I used the model.train_on_batch(...) method to train the network on each item arriving via MQTT:

model = load_model()

def on_message(msg):
    """
    Method called by MQTT client each time new data comes in
    """

    if msg.topic == 'my/topic':
        X, Y = prepare_data(msg.payload)

        prediction = model.predict(X)
        loss = model.train_on_batch(X, Y)

        send_to_visualization_tool(prediction, loss)

One training step in this setting takes about 200 ms. However, when I introduce a buffer, e.g. buffering 100 data points, the training time for the whole batch increases only slightly. This suggests that each batch-training call carries a large fixed setup overhead. I also noticed that with batches of size 1 the CPU load is quite high, while the GPU is hardly used at all.
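The observation that a batch of 100 costs only slightly more than a batch of 1 fits a simple fixed-overhead cost model: if each training call pays a fixed overhead t0 plus a per-sample cost t1, a batch of b samples takes t0 + b*t1, so the amortized time per sample is t0/b + t1. The sketch below only evaluates that model; the split of the ~200 ms step into 199.75 ms of overhead and 0.25 ms per sample is a hypothetical assumption, not a measurement.

```python
def batch_time_ms(batch_size: int, overhead_ms: float, per_sample_ms: float) -> float:
    """Modelled wall time of one training call: fixed overhead plus linear per-sample cost."""
    return overhead_ms + batch_size * per_sample_ms


def per_sample_time_ms(batch_size: int, overhead_ms: float, per_sample_ms: float) -> float:
    """Amortized time per sample: the fixed overhead is shared by the whole batch."""
    return batch_time_ms(batch_size, overhead_ms, per_sample_ms) / batch_size


if __name__ == "__main__":
    # Hypothetical split of a ~200 ms single-sample step (not measured values).
    overhead, per_sample = 199.75, 0.25
    for b in (1, 10, 100):
        print(b, batch_time_ms(b, overhead, per_sample), per_sample_time_ms(b, overhead, per_sample))
```

Under these assumed numbers a batch of 100 takes 224.75 ms instead of 200 ms, while the cost per sample drops by almost two orders of magnitude, which is why batching pays off whenever the overhead dominates.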

As an alternative, I have now introduced a synchronized queue: the MQTT client pushes data into it whenever data arrives, and the neural network then consumes, as one batch, all the data that came in while the previous batch was being processed:

import queue  # Python 3 name of the Python 2 Queue module

train_data_queue = queue.Queue()

# MQTT client running in separate thread
def on_message(msg):
    train_data_queue.put(msg.payload)

model = load_model()

while True:
    train_data_batch = dequeue_all(train_data_queue)  # dequeue all items from queue
                                                      # or block until at least one
                                                      # item is present
    X, Y = prepare_data(train_data_batch)

    predictions = model.predict_on_batch(X)
    losses = model.train_on_batch(X, Y)

    send_to_visualization_tool(predictions, losses)
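The helper `dequeue_all` is not defined in the question. A minimal sketch of such a helper, assuming Python 3's standard `queue` module (called `Queue` in Python 2), blocks until the first item arrives and then drains everything else already waiting without blocking. The `timeout` and `max_items` parameters are hypothetical additions; `max_items` can keep a burst of messages from producing an oversized batch.

```python
import queue
from typing import Any, List, Optional


def dequeue_all(q: queue.Queue, timeout: Optional[float] = None,
                max_items: Optional[int] = None) -> List[Any]:
    """Block until at least one item is available, then drain what is already queued.

    Raises queue.Empty if `timeout` (seconds) elapses before the first item arrives.
    """
    items = [q.get(block=True, timeout=timeout)]  # wait for the first item
    while max_items is None or len(items) < max_items:
        try:
            items.append(q.get_nowait())  # take whatever else is already waiting
        except queue.Empty:
            break
    return items


if __name__ == "__main__":
    q = queue.Queue()
    for payload in ("a", "b", "c"):
        q.put(payload)
    print(dequeue_all(q))  # ['a', 'b', 'c']
```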

This approach works okay, but it would be nice to get rid of the additional complexity of synchronized queues and multithreading, i.e. to get the first approach to work.
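One way to keep the single-threaded structure of the first approach while still training on batches is to buffer inside the callback itself and only call into the model once a batch has accumulated, similar to buffering 100 data points as described earlier. The sketch below uses a hypothetical `MicroBatcher` class; `train_fn` stands in for something like `prepare_data` followed by `model.predict_on_batch` and `model.train_on_batch`. It has no background timer: a partially filled buffer is flushed only when a new message arrives after `max_wait_s` has elapsed (or when `flush()` is called), and training still blocks the MQTT callback while it runs.

```python
import time
from typing import Any, Callable, List, Optional


class MicroBatcher:
    """Collects items in the caller's thread and hands them to `train_fn` in batches."""

    def __init__(self, train_fn: Callable[[List[Any]], Any], batch_size: int = 100,
                 max_wait_s: float = 1.0, clock: Callable[[], float] = time.monotonic):
        self.train_fn = train_fn
        self.batch_size = batch_size
        self.max_wait_s = max_wait_s
        self.clock = clock
        self.buffer: List[Any] = []
        self.first_item_time: Optional[float] = None

    def add(self, item: Any) -> Optional[Any]:
        """Buffer one item; run a training step if the batch is full or has waited too long."""
        if not self.buffer:
            self.first_item_time = self.clock()
        self.buffer.append(item)
        full = len(self.buffer) >= self.batch_size
        stale = self.clock() - self.first_item_time >= self.max_wait_s
        if full or stale:
            return self.flush()
        return None

    def flush(self) -> Optional[Any]:
        """Train on whatever is buffered (if anything) and clear the buffer."""
        if not self.buffer:
            return None
        batch, self.buffer = self.buffer, []
        self.first_item_time = None
        return self.train_fn(batch)


if __name__ == "__main__":
    batcher = MicroBatcher(train_fn=len, batch_size=3, max_wait_s=60.0)
    results = [batcher.add(x) for x in range(7)]
    print(results)          # [None, None, 3, None, None, 3, None]
    print(batcher.flush())  # 1
```

In the question's callback this would replace the per-message training with a call such as `batcher.add(msg.payload)`, where `train_fn` builds X and Y from the list of payloads. Injecting `clock` only exists to make the time-based flush easy to test.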

My question therefore is: is there a way to reduce the overhead of training steps on batches of size 1, e.g. by reimplementing the model in pure TensorFlow? Or can you think of a better way to do real-time training with Keras?

The performance of Keras should be broadly similar to that of raw TensorFlow, so I do not recommend rewriting your model.

Indeed, modern hardware usually takes about the same time to train on a single example as on a batch of examples, which is why so much effort goes into batching things up. If you want to get rid of the explicit synchronized queue, you can use tf.contrib.batching.batch_function, but you will still need to feed it from many threads to get the extra throughput.
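The idea behind such a batching function can be illustrated without TensorFlow: many threads each submit a single example, and one worker groups whatever arrives close together into a single call. The sketch below is a pure-Python analogue built on `threading` and `concurrent.futures.Future`; it is not the `tf.contrib.batching` API, and `BatchingFunction`, `max_batch_size` and `timeout_s` are hypothetical names. It assumes `batch_fn` returns exactly one output per input, in order.

```python
import queue
import threading
from concurrent.futures import Future
from typing import Any, Callable, List


class BatchingFunction:
    """Runs `batch_fn` once per batch assembled from concurrent single-item calls."""

    def __init__(self, batch_fn: Callable[[List[Any]], List[Any]],
                 max_batch_size: int = 32, timeout_s: float = 0.01):
        self.batch_fn = batch_fn
        self.max_batch_size = max_batch_size
        self.timeout_s = timeout_s
        self.requests: queue.Queue = queue.Queue()
        # A single daemon worker thread assembles and runs the batches.
        threading.Thread(target=self._worker, daemon=True).start()

    def __call__(self, item: Any) -> Any:
        """Submit one item from any thread and block until its batch has been processed."""
        fut: Future = Future()
        self.requests.put((item, fut))
        return fut.result()

    def _worker(self) -> None:
        while True:
            pending = [self.requests.get()]  # block until the first request arrives
            try:
                # Keep adding requests while they arrive less than timeout_s apart,
                # up to max_batch_size.
                while len(pending) < self.max_batch_size:
                    pending.append(self.requests.get(timeout=self.timeout_s))
            except queue.Empty:
                pass
            items = [item for item, _ in pending]
            try:
                outputs = self.batch_fn(items)
                if len(outputs) != len(items):
                    raise ValueError("batch_fn must return one output per input")
                for (_, fut), out in zip(pending, outputs):
                    fut.set_result(out)
            except Exception as exc:
                # Propagate the failure to every caller still waiting on this batch.
                for _, fut in pending:
                    if not fut.done():
                        fut.set_exception(exc)


if __name__ == "__main__":
    from concurrent.futures import ThreadPoolExecutor

    batch_sizes = []

    def double_all(xs):
        batch_sizes.append(len(xs))  # record how the calls were grouped
        return [2 * x for x in xs]

    double = BatchingFunction(double_all, max_batch_size=8, timeout_s=0.05)
    with ThreadPoolExecutor(max_workers=8) as pool:
        results = list(pool.map(double, range(16)))
    print(results)
    print(batch_sizes)  # grouping varies with thread timing
```

Each caller still blocks on its own result, so the gain in throughput only appears when many threads call concurrently, which is the same caveat the answer gives for the TensorFlow batching function.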
