[英]How to find the points of intersection of a line and multiple curves in Python?
[英]Plotting multiple curves and intersection points using for loop?
我目前想在 Python 中把多个指数函数(大约 100 个)绘制在同一张图上。
我已知指数函数参数 `a` 和 `b` 的值:
import numpy as np


def exponenial_func(x, a, b):
    """Return ``a * exp(-b * x)``, element-wise when *x* is an array.

    Parameters
    ----------
    x : float or numpy.ndarray
        Evaluation point(s).
    a : float
        Amplitude; the value of the function at ``x = 0``.
    b : float
        Rate constant multiplying ``-x`` in the exponent.
    """
    return a * np.exp(-b * x)


# Fitted parameters (a, b).  Stored as an ndarray rather than a list so that
# scaled copies such as ``1.5 * popt1`` multiply element-wise; with a plain
# list the same expression raises "TypeError: can't multiply sequence by
# non-int".
popt1 = np.array([8.05267292e+03, 1.48857925e+00])
x = np.linspace(0, 15, 30000)
yfun = exponenial_func(x, *popt1)
我现在希望针对 `popt1` 的倍数(从 1 倍到 10 倍,步长为 0.1)绘制 `exponenial_func`。我还想在同一张图上标出第 n 条与第 n-1 条曲线之间的交点(如果可能的话)。
到目前为止,我尝试了下面的代码,但它无法正常工作:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import scipy.optimize as optimize
from scipy.optimize import curve_fit
import matplotlib.pylab as pl
def exponenial_func(x, a, b):
    """Evaluate the decaying exponential ``a * exp(-b * x)`` for scalar or array *x*."""
    scaled_exponent = -(b * x)
    return np.exp(scaled_exponent) * a
x = np.linspace(0, 15, 30000)
# Fitted (a, b) parameters for exponenial_func.
# NOTE(review): this is a plain list, so `scalar * popt1` inside the loop is
# not an element-wise scale (a float factor raises TypeError);
# np.asarray(popt1) would make those multiplications work.
popt1 = [8.05267292e+03, 1.48857925e+00]
# use a loop to plot multiples of popt1 applied to
# exponenial_func from 1x to 10x in steps of
# 0.1x (100 plots total) - create envelope
# --------------------------------------
# create color palette
# -----------------
n = 100
# choose a matplotlib color map
colors = pl.cm.gist_heat(np.linspace(0, 1, n))
# -----------------
fig = plt.figure(figsize=(4.5, 3.6))
ax = fig.add_subplot(1, 1, 1)
ax.set_ylim([1e2, 1e5])
ax.set_xlim([0, 1])
ax.set_yscale("log")
ax.spines['right'].set_visible(True)
ax.spines['top'].set_visible(True)
ax.spines['left'].set_visible(True)
ax.spines['bottom'].set_visible(True)
ax.set_xlabel('x')
ax.set_ylabel('y')
# enable minor ticks
ax.minorticks_on()
# put grid behind curves
ax.set_axisbelow(True)
ax.xaxis.grid(True, which='minor')
ax.xaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())
# turn on major grid
# NOTE(review): recent Matplotlib releases renamed the `b` keyword of
# Axes.grid to `visible`; confirm against the installed version.
ax.grid(b=True, which='major', color='black', linestyle='-', zorder=1, linewidth=0.4, alpha=0.12)
# turn on minor grid
ax.grid(b=True, which='minor', color='black', linestyle='-', zorder=1, linewidth=0.4, alpha=0.12)
ax.tick_params(direction='out', axis='both', which='both', pad=4)
ax.xaxis.set_ticks_position('bottom')
# NOTE(review): the third argument of np.arange is the *step*, so
# np.arange(1, 10, n) with n = 100 yields only array([1]) and the loop body
# runs once. Multipliers 1.0, 1.1, ..., 10.0 would be np.linspace(1, 10, 91).
for i in np.arange(1, 10, n):
    popt_i = i * popt1
    # find the previous set of parameters (cannot multiply list by float?)
    # NOTE(review): correct -- (i - 10 / n) is a float and popt1 is a list, so
    # this line raises "TypeError: can't multiply sequence by non-int".
    # Also, 10 / n == 0.1 is the previous multiplier only if the loop steps by 0.1.
    popt_prev = (i - 10 / n) * popt1
    yfun_i = exponenial_func(x, *popt_i)
    yfun_prev = exponenial_func(x, *popt_prev)
    # Indices j where yfun_i - yfun_prev changes sign, i.e. the two curves
    # cross somewhere between x[j] and x[j + 1].
    idx_i = np.argwhere(np.diff(np.sign(yfun_i - yfun_prev))).flatten()
    # NOTE(review): colors[i] uses the loop value itself as the palette index;
    # with fractional multipliers this needs an integer counter (e.g. enumerate).
    ax.plot(x, yfun_i, zorder=1, c=colors[i], linewidth=1, alpha=1)
    ax.scatter(x[idx_i], yfun_i[idx_i], s=4, alpha=1, zorder=4, color="black")
plt.savefig('test.png', dpi=300, bbox_inches='tight', format='png')
如果注释掉以下几行:
#yfun_prev = exponenial_func(x, *popt_prev)
#idx_i = np.argwhere(np.diff(np.sign(yfun_i - yfun_prev))).flatten()
#ax.scatter(x[idx_i], yfun_i[idx_i], s=4, alpha=1, zorder=4, color="black")
即删除所有与求交点相关的代码后,代码可以运行,但得到的图只有这样(原图在转载中缺失):
如果代码中保留以上几行,则会出现如下错误:
File "envelope.py", line 63, in <module>
popt_prev = (i - 10 / n) * popt1
TypeError: can't multiply sequence by non-int of type 'numpy.float64'
有谁知道如何在 Python 中实现这一点?
下面的代码总结了我的所有评论,修复了小错误并修改了原始代码中的一些视觉参数,结果如下图所示:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.pylab as pl
def fx(x, a, b):
    """Evaluate the exponential decay model ``a * exp(-b * x)``.

    Works element-wise when *x* is a NumPy array.
    """
    decay = np.exp(-(b * x))
    return decay * a
# Plot fx for the scaled parameter sets scale * popt1 with
# scale = 1.0, 1.1, ..., 10.0, and mark where each curve crosses the curve
# drawn immediately before it.
n = 91  # number of curves: 1.0x to 10.0x in steps of 0.1x
# linspace gives an exact 0.1 spacing here ((10 - 1) / (91 - 1) == 0.1)
# and a guaranteed length of n, unlike arange with a float step.
scales = np.linspace(1, 10, n)
x = np.linspace(0, 15, 30000)
# ndarray (not a list) so that `scale * popt1` multiplies element-wise.
popt1 = np.array([8.05267292e+03, 1.48857925e+00])
colorspace = pl.cm.prism(np.linspace(0, 1, n))
fig = plt.figure(figsize=(4.5, 3.6))
ax = fig.add_subplot(1, 1, 1)
ax.set_ylim([1e2, 1e5])
ax.set_xlim([0, 1])
ax.set_yscale("log")
# Slicing ax.spines requires Matplotlib >= 3.4.
ax.spines[:].set_visible(True)
ax.spines[:].set_linewidth(0.4)
ax.set_xlabel('x')
ax.set_ylabel('f(x) = a * exp(-b * x)')
ax.minorticks_on()
ax.set_axisbelow(True)
ax.xaxis.grid(True, which='minor')
ax.xaxis.set_minor_formatter(matplotlib.ticker.NullFormatter())
ax.grid(True, which='major', alpha=.75, color='pink', ls='solid', lw=.5, zorder=1)
ax.grid(True, which='minor', alpha=.75, color='pink', ls='solid', lw=.5, zorder=1)
ax.tick_params(direction='out', axis='both', pad=4, width=.5, which='both')
ax.xaxis.set_ticks_position('bottom')
# The previous curve is carried over from the last iteration, so it is
# exactly the (n-1)th curve that was drawn; the first curve has no
# predecessor and therefore gets no intersection markers.
yfun_prev = None
for i, scale in enumerate(scales):
    yfun_i = fx(x, *(scale * popt1))
    ax.plot(x, yfun_i, c=colorspace[i], lw=.25, zorder=2)
    if yfun_prev is not None:
        # j is selected when yfun_i - yfun_prev changes sign between x[j] and
        # x[j + 1], i.e. the two curves cross in that interval; the marker is
        # placed at the left grid point x[j].
        idx_i = np.flatnonzero(np.diff(np.sign(yfun_i - yfun_prev)))
        ax.scatter(x[idx_i], yfun_i[idx_i], s=.5, zorder=3)
    yfun_prev = yfun_i
plt.savefig('test.png', dpi=300, bbox_inches='tight', format='png')
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.