Is matplotlib's scatter plot slow for a large number of data points?

I have a dataset with attributes x and y, which can be plotted in the xy plane.

Originally, I used this code:

import matplotlib.pyplot as plt
# df is a pandas DataFrame with columns 'x' and 'y'
df.plot(kind='scatter', x='x', y='y', alpha=0.10, s=2)
plt.gca().set_aspect('equal')

This is pretty quick with about 50,000 points.

Recently I switched to a newer dataset with about 2,500,000 points, and the scatter plot became much slower.

Is this expected behavior, and is there anything I can do to speed up the plot?

Yes, it is expected. A scatter plot of more than maybe a thousand points makes very little sense, so no one bothered to optimise it. You will be better off using some other representation for your data:

  • A heatmap, if your points are distributed all over the place. Make the heatmap cells pretty small.
  • Draw some sort of curve that approximates the distribution, for example by correlating y with x. Be sure to provide confidence values or describe the distribution some other way; for me, building a box-and-whisker plot of y for every x (or range of x) and placing them on the same grid usually works pretty well.
  • Reduce your dataset. @sascha suggests random sampling in the comments, and that's definitely a good idea. Depending on your data, there may be a better way to choose representative points.
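
As a rough illustration of the three options above, here is a minimal sketch using NumPy, with matplotlib imported only inside the plotting helper. The function names (random_subsample, density_grid, binned_quartiles, plot_density) and the default bin counts are assumptions made up for this example, not part of any library, and would likely need tuning for real data.

```python
import numpy as np


def random_subsample(x, y, n_keep, seed=0):
    """Return n_keep points drawn uniformly at random, without replacement."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(len(x), size=min(n_keep, len(x)), replace=False)
    return x[idx], y[idx]


def density_grid(x, y, bins=500):
    """Count points per grid cell; returns (counts, x_edges, y_edges)."""
    counts, x_edges, y_edges = np.histogram2d(x, y, bins=bins)
    return counts, x_edges, y_edges


def binned_quartiles(x, y, n_bins=20):
    """Quartiles of y inside equal-width x bins (a box-plot style summary)."""
    edges = np.linspace(x.min(), x.max(), n_bins + 1)
    # Digitizing against the interior edges yields bin ids 0 .. n_bins - 1.
    which = np.digitize(x, edges[1:-1])
    centers = 0.5 * (edges[:-1] + edges[1:])
    quartiles = np.full((n_bins, 3), np.nan)  # NaN marks empty bins
    for b in range(n_bins):
        yb = y[which == b]
        if yb.size:
            quartiles[b] = np.percentile(yb, [25, 50, 75])
    return centers, quartiles


def plot_density(counts, x_edges, y_edges):
    """Draw the count grid as a single image instead of one marker per point."""
    import matplotlib.pyplot as plt  # imported lazily; only needed for drawing

    fig, ax = plt.subplots()
    # histogram2d puts x on axis 0, so transpose to get rows = y for imshow.
    im = ax.imshow(counts.T, origin="lower", cmap="hot", aspect="equal",
                   extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]])
    fig.colorbar(im, ax=ax)
    return fig, ax
```

With x and y as NumPy arrays (for a DataFrame, something like df['x'].to_numpy()), one could call plot_density(*density_grid(x, y)) for the heatmap, or pass the output of random_subsample(x, y, 50000) to the original scatter call. Which option works best depends on the data.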

I had the same problem with more than 300k 2D coordinates from a dimension reduction algorithm. The solution was to approximate those coordinates into a 2D numpy array and visualize it as an image. The result was pretty good and also much faster:

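import numpy as np
import matplotlib.pyplot as plt

def plot_to_buf(data, height=2800, width=2800, inc=0.3):
    # Bounding box of the points; it is mapped onto the pixel grid below.
    xlims = (data[:,0].min(), data[:,0].max())
    ylims = (data[:,1].min(), data[:,1].max())
    dxl = xlims[1] - xlims[0]
    dyl = ylims[1] - ylims[0]

    print('xlims: (%f, %f)' % xlims)
    print('ylims: (%f, %f)' % ylims)

    # One cell per pixel; each point adds `inc` to its cell, capped at 1.0.
    buffer = np.zeros((height+1, width+1))
    for i, p in enumerate(data):
        # Report progress occasionally; printing on every point is slow.
        if i % 10000 == 0:
            print('\rloading: %03d' % (float(i)/data.shape[0]*100), end=' ')
        x0 = int(round(((p[0] - xlims[0]) / dxl) * width))
        # Flip y so that larger y values land in the top rows of the image.
        y0 = int(round((1 - (p[1] - ylims[0]) / dyl) * height))
        buffer[y0, x0] = min(buffer[y0, x0] + inc, 1.0)
    return xlims, ylims, buffer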

data = load_data() # data.shape = (310216, 2) <<< your data here
h, w = 2800, 2800  # image size in pixels (assumed; same as the function defaults)
xlims, ylims, I = plot_to_buf(data, height=h, width=w, inc=0.3)
ax_extent = list(xlims)+list(ylims)  # [xmin, xmax, ymin, ymax]
plt.imshow(I,
           vmin=0,
           vmax=1,
           cmap=plt.get_cmap('hot'),
           interpolation='lanczos',
           aspect='auto',
           extent=ax_extent
           )
plt.grid(alpha=0.2)
plt.title('Latent space')
plt.colorbar()
plt.show()
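
The per-point Python loop in plot_to_buf can itself become the bottleneck once the dataset reaches millions of rows. Below is a sketch of a vectorized variant that should produce the same buffer, up to floating-point rounding. The name rasterize_points and the guard for a zero-width coordinate range are additions made for this example. It relies on np.add.at so that repeated hits on the same cell accumulate.

```python
import numpy as np


def rasterize_points(data, height=2800, width=2800, inc=0.3):
    """Vectorized variant of the per-point loop.

    Returns (xlims, ylims, buffer), where buffer has shape
    (height + 1, width + 1) and values in [0, 1].
    """
    data = np.asarray(data, dtype=float)
    xlims = (data[:, 0].min(), data[:, 0].max())
    ylims = (data[:, 1].min(), data[:, 1].max())
    # Guard (an addition): avoid dividing by zero if all points share a coordinate.
    dx = (xlims[1] - xlims[0]) or 1.0
    dy = (ylims[1] - ylims[0]) or 1.0

    # Map every point to a pixel in one shot; y is flipped so that
    # larger y values end up in the top rows of the image.
    cols = np.rint((data[:, 0] - xlims[0]) / dx * width).astype(int)
    rows = np.rint((1.0 - (data[:, 1] - ylims[0]) / dy) * height).astype(int)

    buffer = np.zeros((height + 1, width + 1))
    # Unbuffered add: cells hit by several points receive inc once per hit.
    np.add.at(buffer, (rows, cols), inc)
    # Cap at 1.0, matching the clamp inside the original loop.
    np.minimum(buffer, 1.0, out=buffer)
    return xlims, ylims, buffer
```

The return value has the same layout as plot_to_buf, so the imshow call above can be reused unchanged with xlims, ylims, I = rasterize_points(data, height=h, width=w, inc=0.3).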

Here is the result: [figure: the point density rendered as an image titled 'Latent space']

I hope this helps you.

