
Improve performance of matplotlib for subset of data

I have a small PyQt4 application that shows plots for a large data set (100k points x 14 channels). I only want to show one period of 128 points at a time and click to show the next period.

My naive approach was to create the figure once and plot only a subset of my data on each step of the loop. This leads to a loading time of about a second per step, which seems too much for this task.

Is there any way to improve the performance? Did I miss some matplotlib built-in functions to plot only a subset of data? I wouldn't mind a longer loading time at the beginning of the application, so maybe I could plot it all and zoom in?

EDIT: Provided a simple running example

Took 7.39s to plot 8 samples on my machine

import time

import matplotlib.pyplot as plt
import numpy as np


plt.ion()

num_channels = 14
num_samples = 1024

data = np.random.rand(num_channels, num_samples)

figure = plt.figure()

start = 0
period = 128

axes = []
for i in range(num_channels):
    axes.append(figure.add_subplot(num_channels, 1, i+1))

end = start+period
x_values = [x for x in range(start, end)]

begin = time.time()
num_plot = 0
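for _ in range(0, num_samples, period):
    num_plot += 1
    end = start+period

    for i, ax in enumerate(axes):
        # ax.hold(False) was removed in matplotlib 3.0; clearing the axes
        # before each plot gives the same "replace the old line" behaviour
        ax.clear()
        ax.plot(x_values, data[i][start:end], '-')
        ax.set_ylabel(i)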
    start += period

    figure.canvas.draw()
print("Took %.2fs to plot %d samples" % (time.time()-begin, num_plot))

Following @joe-kington's answer to "How to update a plot in matplotlib" improved performance to an acceptable level.

I now only change the y-values of the line object using set_ydata(). The line object is returned by ax.plot(), which is called only once per axes.
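One caveat with set_ydata(): it replaces the line's data but does not rescale the axes, so if a later period has a different value range the line can run off the plot. The following is a minimal sketch of how the limits could be refreshed with ax.relim() and ax.autoscale_view(). The single axes, the linspace test data, and the Agg backend are assumptions made only so the sketch runs on its own; they are not part of the application above.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

period = 128
x_values = np.arange(period)

fig, ax = plt.subplots()

# plot once and keep the line object
line, = ax.plot(x_values, np.linspace(0.0, 1.0, period), '-')
old_ylim = ax.get_ylim()

# swap in data with a much larger range: the y-limits stay where they were
line.set_ydata(np.linspace(-50.0, 50.0, period))
stale_ylim = ax.get_ylim()

# recompute the data limits from the artists, then rescale the view
ax.relim()
ax.autoscale_view()
new_ylim = ax.get_ylim()

fig.canvas.draw()
```

If all channels share a known, fixed range, setting the limits once with ax.set_ylim() avoids the rescaling work on every step.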

EDIT: Added a running example: Took 3.11s to plot 8 samples on my machine

import time

import matplotlib.pyplot as plt
import numpy as np


plt.ion()

num_channels = 14
num_samples = 1024

data = np.random.rand(num_channels, num_samples)

figure = plt.figure()

start = 0
period = 128

axes = []
for i in range(num_channels):
    axes.append(figure.add_subplot(num_channels, 1, i+1))

end = start+period
x_values = [x for x in range(start, end)]

lines = []
begin = time.time()
num_plot = 1 # first plot
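for i, ax in enumerate(axes):
    # ax.hold(False) is not needed here (and was removed in matplotlib 3.0):
    # each axes is plotted to exactly once

    # save the line object so its data can be updated later
    line, = ax.plot(x_values, data[i][start:end], '-')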
    lines.append(line)

    ax.set_xlim([start,end])
    ax.set_ylabel(i)

start += period
for _ in range(period, num_samples, period):
    num_plot += 1
    end = start + period
    for i, line in enumerate(lines):
        line.set_ydata(data[i][start:end])
    start += period

    figure.canvas.draw()
print("Took %.2fs to plot %d samples" % (time.time()-begin, num_plot))
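The question also raised the idea of plotting everything once and then "zooming in". A minimal sketch of that alternative, assuming shared x-axes and a hypothetical helper named show_period (neither is in the original code), could look like the following. No timing is claimed for it; whether it beats set_ydata() on the full 100k-point data set would need to be measured. The Agg backend is selected only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed inside the Qt app
import matplotlib.pyplot as plt
import numpy as np

num_channels = 14
num_samples = 1024
period = 128

data = np.random.rand(num_channels, num_samples)
x_values = np.arange(num_samples)

# sharex=True links the x-limits, so moving one axes moves all of them
figure, axes = plt.subplots(num_channels, 1, sharex=True)

# plot each full channel exactly once
for i, ax in enumerate(axes):
    ax.plot(x_values, data[i], '-')
    ax.set_ylabel(i)


def show_period(start):
    """Hypothetical helper: show samples start .. start + period on every axes."""
    axes[0].set_xlim(start, start + period)
    figure.canvas.draw()


# step through the data one period at a time, recording the visible window
windows = []
for start in range(0, num_samples, period):
    show_period(start)
    windows.append(axes[-1].get_xlim())
```

In the application, show_period would be called from the click handler with the next start index, instead of from a loop.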
