
How can I update matplotlib subplot axes in an interactive plot when the subplots are viewing different axis regions?

I want to use the interactive subplots in matplotlib and have the changes I make (scrolling, cropping) apply to every subplot. Usually, I would use

ax1 = plt.subplot(211)
ax2 = plt.subplot(212, sharex=ax1)

but I do not want the axes to be identical. I want them to be linked, yet able to display different x-ranges. For example, I would want my initial plot to look like this:

[Figure: three subplots showing a sine curve, each with a slightly offset x-axis]

Then I would like to use the interactive plot tools to scroll across so that every subplot moves by the same amount in the x-direction.
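For reference, a static version of that initial layout could be produced along these lines. This is a minimal sketch: the base window of 0 to 2π and the offsets of 0, 0.5 and 1 are assumed values, since the exact ranges in the original figure are not given.

```python
import numpy as np
import matplotlib.pyplot as plt

# Assumed values: a base window of one period and small per-subplot offsets.
window = (0.0, 2 * np.pi)
offsets = [0.0, 0.5, 1.0]

x = np.linspace(0, 4 * np.pi, 500)
fig, axes = plt.subplots(3, 1)
for ax, offset in zip(axes, offsets):
    ax.plot(x, np.sin(x))
    # Each subplot shows the same curve through a shifted x-window.
    ax.set_xlim(window[0] + offset, window[1] + offset)
```

Calling `plt.show()` afterwards displays it; on its own, each subplot still pans and zooms independently.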

I have tried

a,b = ax1.get_xlim()
ax2 = plt.subplot(212, xlim=(a - 1, b - 1))

but get_xlim() does not return live, interactive updates to a and b; it only reads them once, when the plot is first made.
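The value returned by `get_xlim()` is a snapshot. To see live values, the read has to happen inside a callback that matplotlib fires whenever the limits change, i.e. the `xlim_changed` event. A minimal sketch of the `a - 1, b - 1` idea done that way, assuming a one-way link where only `ax1` drives `ax2`; the function name `follow` is hypothetical:

```python
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(2, 1)
ax1.set_xlim(0, 10)


def follow(ax):
    # matplotlib calls this with ax1 every time ax1's x-limits change,
    # so get_xlim() here returns the current (live) limits.
    a, b = ax.get_xlim()
    ax2.set_xlim(a - 1, b - 1)


ax1.callbacks.connect("xlim_changed", follow)
follow(ax1)  # apply the offset once for the initial view
```

Panning or zooming `ax1` then keeps `ax2` one unit to the left. Panning `ax2` does not move `ax1`, because only `ax1` has the callback.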

I can come up with two approaches, each with its own drawbacks.

Approach 1: hook into xlim_changed and ylim_changed axis callbacks

This approach works for panning, zooming, scroll events, and the home/forward/backward buttons.

However, only one axis can drive the other axis or axes: connecting every axis to every other one causes infinite recursion, because each callback's set_xlim/set_ylim call fires the other axes' callbacks in turn.

import numpy as np
import matplotlib.pyplot as plt


class XAxisUpdater(object):
    """Shift `slave`'s x-limits by however much `master`'s x-limits change."""

    def __init__(self, master, slave):
        self.master = master
        self.slave = slave
        self.old_limits = self.get_limits(self.master)

    def __call__(self, _ax):
        # matplotlib passes the Axes that changed; here that is always `master`.
        new_limits = self.get_limits(self.master)
        deltas = new_limits - self.old_limits
        # Apply the same (left, right) shift to the slave, preserving its offset.
        self.set_limits(self.slave, self.get_limits(self.slave) + deltas)
        self.old_limits = new_limits

    def get_limits(self, ax):
        return np.array(ax.get_xlim())

    def set_limits(self, ax, limits):
        ax.set_xlim(limits)


class YAxisUpdater(XAxisUpdater):
    """Same as XAxisUpdater, but for the y-limits."""

    def get_limits(self, ax):
        return np.array(ax.get_ylim())

    def set_limits(self, ax, limits):
        ax.set_ylim(limits)


if __name__ == '__main__':

    fig, (ax1, ax2) = plt.subplots(1, 2)
    ax1.set_xlim(0, 10)
    ax2.set_xlim(5, 15)
    # Only ax1 drives ax2; connecting ax2 back to ax1 as well would recurse forever.
    ax1.callbacks.connect('xlim_changed', XAxisUpdater(ax1, ax2))
    ax1.callbacks.connect('ylim_changed', YAxisUpdater(ax1, ax2))
    plt.show()
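One way to let every subplot drive the others, sketched here as a variation rather than taken from the approach above, is to route all axes through a single object that remembers each axes' previous limits and ignores the callbacks triggered by its own `set_xlim`/`set_ylim` calls. `OffsetLinker` and its `_busy` re-entrancy flag are hypothetical names. Because matplotlib's callback registry keeps only weak references to bound methods, the linker object must be kept alive in a variable.

```python
import numpy as np


class OffsetLinker:
    """Hypothetical helper: pan/zoom several axes together, keeping their offsets.

    Any linked axes can drive the others. The `_busy` flag stops the callbacks
    fired by our own set_xlim/set_ylim calls from recursing.
    """

    def __init__(self, axes, which="x"):
        self.axes = list(axes)
        self.which = which
        self._busy = False
        # Last known limits of every linked axes, as numpy arrays.
        self.old = {ax: self._get(ax) for ax in self.axes}
        for ax in self.axes:
            # Bound methods are held weakly: keep a reference to this object.
            ax.callbacks.connect(f"{which}lim_changed", self._on_change)

    def _get(self, ax):
        return np.array(ax.get_xlim() if self.which == "x" else ax.get_ylim())

    def _set(self, ax, limits):
        if self.which == "x":
            ax.set_xlim(*limits)
        else:
            ax.set_ylim(*limits)

    def _on_change(self, changed_ax):
        if self._busy:
            return  # triggered by our own _set call below; ignore it
        self._busy = True
        try:
            # Per-edge shift of the axes the user just panned or zoomed.
            delta = self._get(changed_ax) - self.old[changed_ax]
            for ax in self.axes:
                if ax is not changed_ax:
                    self._set(ax, self.old[ax] + delta)
            for ax in self.axes:
                self.old[ax] = self._get(ax)
        finally:
            self._busy = False
```

For the question's layout this would be something like `x_link = OffsetLinker(axes, which="x")` and `y_link = OffsetLinker(axes, which="y")` before `plt.show()`. Since it reacts to the same `xlim_changed`/`ylim_changed` events as Approach 1, it should also follow the home/back/forward buttons, although how the toolbar restores each axes in turn is worth verifying.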

Approach 2: detect panning (and zooming) events

This doesn't result in recursion. However, this approach breaks if the home, forward, or back button is used.

Additionally, my implementation below doesn't trigger on zoom, for reasons I can't discern (even though, in my opinion, it should). Maybe someone else can spot and correct the issue.

import numpy as np
import matplotlib.pyplot as plt


class PanEvent(object):
    """Call `func` while the mouse is dragged inside `ax` with `button` held."""

    def __init__(self, ax, func, button=1):
        self.ax = ax
        self.func = func
        self.button = button
        self.press = False
        canvas = self.ax.figure.canvas
        canvas.mpl_connect('button_press_event', self.on_press)
        canvas.mpl_connect('button_release_event', self.on_release)
        canvas.mpl_connect('motion_notify_event', self.on_move)

    def on_press(self, event):
        if event.inaxes == self.ax and event.button == self.button:
            self.press = True

    def on_move(self, event):
        if event.inaxes == self.ax and self.press:
            self.func()

    def on_release(self, event):
        # Reset even if the button is released outside the axes; otherwise
        # later mouse moves would keep calling func.
        self.press = False


class AxisUpdater(object):
    """Shift every axes by the (x, y) delta of whichever axes moved."""

    def __init__(self, axes):
        self.axes = axes
        self.limits = np.array([ax.axis() for ax in self.axes])

    def __call__(self):
        # Find the first axes whose limits differ from the stored ones.
        for ax, old_limits in zip(self.axes, self.limits):
            deltas = np.array(ax.axis()) - old_limits
            if not np.all(np.isclose(deltas, 0)):
                break

        # Apply that delta to all axes (including the one that moved).
        for ii, (ax, old_limits) in enumerate(zip(self.axes, self.limits)):
            self.limits[ii] = old_limits + deltas
            ax.axis(self.limits[ii])


if __name__ == '__main__':

    fig, axes = plt.subplots(1, 2)
    axes[0].set_xlim(0, 10)
    axes[1].set_xlim(5, 15)
    # Share one AxisUpdater so both handlers see the same stored limits;
    # two separate instances go out of sync after the first pan.
    updater = AxisUpdater(axes)
    # NB: keep references to the handlers, otherwise they are garbage-collected.
    instance1 = PanEvent(axes[0], updater)
    instance2 = PanEvent(axes[1], updater)
    plt.show()
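A plausible cause of the zoom problem, offered as an assumption rather than a verified diagnosis: in zoom-to-rectangle mode the toolbar only changes the limits when the mouse button is released, so every motion callback sees unchanged limits, and `on_release` never calls the updater. The hypothetical `PressDragRelease` handler below is a drop-in variant of `PanEvent` that calls `func` once more on release. It assumes the toolbar's own release handler, connected when the figure was created, runs before this one, so the new limits are already in place.

```python
class PressDragRelease:
    """Hypothetical PanEvent variant: call `func` while dragging inside `ax`
    and once more when the button is released (to catch zoom-to-rectangle)."""

    def __init__(self, ax, func, button=1):
        self.ax = ax
        self.func = func
        self.button = button
        self.press = False
        canvas = ax.figure.canvas
        # Bound methods are held weakly: keep a reference to this handler.
        self.cids = [
            canvas.mpl_connect("button_press_event", self.on_press),
            canvas.mpl_connect("motion_notify_event", self.on_move),
            canvas.mpl_connect("button_release_event", self.on_release),
        ]

    def on_press(self, event):
        if event.inaxes is self.ax and event.button == self.button:
            self.press = True

    def on_move(self, event):
        # Pan: limits change during the drag.
        if self.press and event.inaxes is self.ax:
            self.func()

    def on_release(self, event):
        # Zoom: limits change only on release, so sync once more here.
        if self.press:
            self.press = False
            self.func()
```

Usage mirrors `PanEvent` with one shared updater, e.g. `updater = AxisUpdater(axes)` followed by `handlers = [PressDragRelease(ax, updater) for ax in axes]`.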