Secondary y-axis in sympy

SymPy plotting has evolved a lot recently, but I cannot find a way to plot a second expression on a secondary y-axis. The example below combines two plots; I would like the second one to use a secondary axis:

>>> from sympy import symbols
>>> from sympy.plotting import plot
>>> x = symbols('x')
>>> p1 = plot(x**4, (x, 10, 20), label='$x^4$', show=False, legend=True)
>>> p2 = plot(x, (x, 10, 20), line_color='red', label='$x$', show=False, legend=True)
>>> p1.extend(p2)
>>> p1.show()

Incidentally, is there a way to also modify the line style, e.g. to make one of the lines dashed? The documentation suggests using the `_backend` attribute for fine-tuning, but it is not clear to me how to achieve this. See e.g. "How do I use the `_backend` attribute of a Sympy `plot`".

The approach from this post can be adapted as follows. Note that SymPy places the spines at the zero positions, which is confusing when two axes share the same plot, so the code below moves them back to their default positions at the plot edges. Matplotlib can then build a combined legend from the handles and labels of both axes.

from sympy import Symbol, plot, sin
import matplotlib.pyplot as plt


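def move_sympyplot_to_axes(p, ax, is_twinx):
    # Draw the series of the SymPy plot p onto the existing matplotlib axes ax.
    # This relies on private backend internals (_process_series, _series),
    # so it may need adjusting for other SymPy versions.
    backend = p.backend(p)
    backend.ax = ax
    backend._process_series(backend.parent._series, ax, backend.parent)
    if is_twinx:
        # secondary axis: hide its left spine
        backend.ax.spines['left'].set_color('none')
    else:
        # primary axis: hide the right spine and move the left spine back to the plot edge
        backend.ax.spines['right'].set_color('none')
        backend.ax.spines['left'].set_position(('axes', 0))
    # move the bottom spine back from y=0 to the bottom edge
    backend.ax.spines['bottom'].set_position(('axes', 0))
    # close the backend's own figure, which is not needed here
    plt.close(backend.fig)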

x = Symbol('x')
p1 = plot(x ** 4, (x, 10, 20), label='$x^4$', show=False)
p2 = plot(sin(10*x)/x**2, (x, 10, 20), adaptive=False, nb_of_points=500,
          line_color='red', label='$sin(10x)/x^2$', show=False)

fig, ax = plt.subplots()
ax2 = ax.twinx()
move_sympyplot_to_axes(p1, ax, is_twinx=False)
move_sympyplot_to_axes(p2, ax2, is_twinx=True)

ax2.tick_params(axis='y', colors='red')
ax2.set_ylabel('', color='red')
handles1, labels1 = ax.get_legend_handles_labels()
handles2, labels2 = ax2.get_legend_handles_labels()
plt.legend(handles1 + handles2, labels1 + labels2, loc='upper center')
plt.tight_layout()
plt.show()

[Example plot]
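
If depending on SymPy's private backend methods is a concern, a different route is to skip the SymPy plotting module and evaluate the expressions with `lambdify`, then draw them with matplotlib directly. The following is only a minimal sketch of that alternative, not the approach shown above; the helper name `plot_on_twin_axes` and its parameters are made up for illustration. It also covers the line-style part of the question, since `linestyle='--'` is an ordinary matplotlib argument.

```python
import numpy as np
import matplotlib.pyplot as plt
from sympy import Symbol, lambdify, sin


def plot_on_twin_axes(var, expr1, expr2, lo, hi, label1, label2, n=500):
    """Hypothetical helper: expr1 on the left y-axis, expr2 (dashed) on the right."""
    xs = np.linspace(lo, hi, n)
    # lambdify turns each SymPy expression into a NumPy-vectorized function
    f1 = lambdify(var, expr1, 'numpy')
    f2 = lambdify(var, expr2, 'numpy')

    fig, ax = plt.subplots()
    ax2 = ax.twinx()  # second y-axis sharing the same x-axis

    ax.plot(xs, f1(xs), label=label1)
    ax2.plot(xs, f2(xs), color='red', linestyle='--', label=label2)
    ax2.tick_params(axis='y', colors='red')

    # one combined legend built from the handles and labels of both axes
    handles1, labels1 = ax.get_legend_handles_labels()
    handles2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(handles1 + handles2, labels1 + labels2, loc='upper center')
    return fig, ax, ax2


x = Symbol('x')
fig, ax, ax2 = plot_on_twin_axes(x, x**4, sin(10*x)/x**2, 10, 20,
                                 label1='$x^4$', label2=r'$\sin(10x)/x^2$')
# call plt.show() or fig.savefig(...) to view the result
```

Because the axes here are plain matplotlib axes, the spines stay at the plot edges and any matplotlib styling option can be applied without going through the SymPy backend.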
