[英]Python: finding the intersection point of two gaussian curves
I have two gaussian plots:我有两个高斯图:
x = np.linspace(-5,9,10000)
plot1=plt.plot(x,mlab.normpdf(x,2.5,1))
plot2=plt.plot(x,mlab.normpdf(x,5,1))
and I want to find the point at where the two curves intersect.我想找到两条曲线相交的点。 Is there a way of doing this?有没有办法做到这一点? In particular I want to find the value of the x-coordinate where they meet.特别是我想找到它们相遇的 x 坐标的值。
You want to find the x's such that both gaussian functions have the same height.(ie intersect)您想找到 x 使得两个高斯函数具有相同的高度。(即相交)
You can do so by equating two gaussian functions and solve for x.您可以通过将两个高斯函数相等并求解 x 来实现。 In the end you will get a quadratic equation with coefficients relating to the gaussian means and variances.最后,您将得到一个二次方程,其中包含与高斯均值和方差相关的系数。 Here is the final result:这是最终结果:
import numpy as np
def solve(m1,m2,std1,std2):
a = 1/(2*std1**2) - 1/(2*std2**2)
b = m2/(std2**2) - m1/(std1**2)
c = m1**2 /(2*std1**2) - m2**2 / (2*std2**2) - np.log(std2/std1)
return np.roots([a,b,c])
m1 = 2.5
std1 = 1.0
m2 = 5.0
std2 = 1.0
result = solve(m1,m2,std1,std2)
The output is :输出是:
array([ 3.75])
You can plot the found intersections:您可以绘制找到的交叉点:
x = np.linspace(-5,9,10000)
plot1=plt.plot(x,mlab.normpdf(x,m1,std1))
plot2=plt.plot(x,mlab.normpdf(x,m2,std2))
plot3=plt.plot(result,mlab.normpdf(result,m1,std1),'o')
The plot will be:情节将是:
If your gaussians have multiple intersections, the code will also find all of them(say m1=2.5, std1=3.0, m2=5.0, std2=1.0):如果你的高斯有多个交集,代码也会找到所有的交集(比如 m1=2.5, std1=3.0, m2=5.0, std2=1.0):
Here's a solution based on purely numpy that is also applicable to curves other than Gaussian.这是一个基于纯 numpy 的解决方案,它也适用于高斯以外的曲线。
def get_intersection_locations(y1,y2,test=False,x=None):
"""
return indices of the intersection point/s.
"""
idxs=np.argwhere(np.diff(np.sign(y1 - y2))).flatten()
if test:
x=range(len(y1)) if x is None else x
plt.figure(figsize=[2.5,2.5])
ax=plt.subplot()
ax.plot(x,y1,color='r',label='line1',alpha=0.5)
ax.plot(x,y2,color='b',label='line2',alpha=0.5)
_=[ax.axvline(x[i],color='k') for i in idxs]
_=[ax.text(x[i],ax.get_ylim()[1],f"{x[i]:1.1f}",ha='center',va='bottom') for i in idxs]
ax.legend(bbox_to_anchor=[1,1])
ax.set(xlabel='x',ylabel='density')
return idxs
# single intersection
x = np.arange(-10, 10, 0.001)
y1=sc.stats.norm.pdf(x,-2,2)
y2=sc.stats.norm.pdf(x,2,3)
get_intersection_locations(y1=y1,y2=y2,x=x,test=True) # returns indice/s array([10173])
# double intersection
x = np.arange(-10, 10, 0.001)
y1=sc.stats.norm.pdf(x,-2,1)
y2=sc.stats.norm.pdf(x,2,3)
get_intersection_locations(y1=y1,y2=y2,x=x,test=True)
Based on an answer to a similar question .基于对类似问题的回答。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.