So I'm trying to fit an exponential curve to some COVID data, but I can't get my curve_fit function to show any sort of curve whatsoever. It's so bad that it perfectly overlaps the regression line seaborn generated in my graph.

I've tried making both my date and case data smaller/bigger before passing them to curve_fit, but I still get either a similar line, an optimization error, or both. I even tried calculating my function manually, but that was (naturally) also way off.
```python
import datetime as dt

import matplotlib.pyplot as plt
import seaborn as sns

# df_sb (columns 'date_ordinal', 'totalcountconfirmed', 'totalcountdeaths')
# and add_sig_dates() are defined elsewhere in my notebook

# Plot scatter plot for total case count
x = df_sb['date_ordinal']
y1 = df_sb['totalcountconfirmed']
y2 = df_sb['totalcountdeaths']
plt.figure(figsize=(14, 10))
ax = plt.subplot(1, 1, 1)

# Plot scatter plot along with linear regression line
sns.regplot(x='date_ordinal', y='totalcountconfirmed', data=df_sb)

# Formatting axes
ax.set_xlim(x.min() - 1, x.max() + 10)
ax.set_ylim(0, y1.max() + 1)
ax.set_xlabel('Date')
labels = [dt.date.fromordinal(int(item)) for item in ax.get_xticks()]
ax.set_xticklabels(labels)
plt.xticks(rotation=45)
plt.ylabel("Total Confirmed Cases")

# Exponential Curve
from scipy.optimize import curve_fit
from scipy.special import expit

x_data = df_sb['date_ordinal'].to_numpy()
Y_data = df_sb['totalcountconfirmed'].to_numpy()

def func(x, a, b, c):
    return a * expit(-b * x) + c

popt, pcov = curve_fit(func, x_data, Y_data, maxfev=10000)
a, b, c = popt
fit_y = func(x_data, a, b, c)
plt.plot(x_data, fit_y)
plt.legend(['Total Cases (Linear)', 'Total Cases (Exponential)'])

# Inserting Significant Date Labels
add_sig_dates(df_sb, 'totalcountconfirmed')
plt.show()
```
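Although you didn't share the data, just from looking at the plot I'm pretty sure you meant

```python
import numpy as np

def func(x, a, b, c):
    return a * np.exp(-b * x) + c
```

instead of

```python
def func(x, a, b, c):
    return a * expit(-b * x) + c
```

`scipy.special.expit` is the logistic sigmoid, 1 / (1 + exp(-x)), so it produces a bounded S-shaped curve rather than an exponential one.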
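To see why the `expit`-based `func` can't trace an exponential curve, this short sketch (an illustration added here, not part of the original answer) evaluates `expit` and `np.exp` on the same inputs and checks `expit` against its definition as the logistic sigmoid.

```python
import numpy as np
from scipy.special import expit

x = np.array([-10.0, 0.0, 10.0])

# expit is the logistic sigmoid 1 / (1 + exp(-x)): always strictly between 0 and 1
sigmoid = expit(x)
manual = 1.0 / (1.0 + np.exp(-x))

# np.exp is the exponential function: unbounded, grows without limit
exponential = np.exp(x)

print(sigmoid)      # values squeezed into (0, 1)
print(exponential)  # values grow rapidly with x
```

Because `expit` saturates at 0 and 1, `a * expit(-b * x) + c` can only ever move between `c` and `a + c`, which is a logistic shape, not exponential growth.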
Since it's an exponential fit, I think you should also provide an initial guess for the parameters to get good results. You can do this with the `p0` argument, for example:

```python
p0 = [2, 1, 0]  # <-- just an example; these are poor guesses
popt, pcov = curve_fit(func, x_data, Y_data, maxfev=10000, p0=p0)
```
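Here is a minimal end-to-end sketch of the exponential fit with `p0`, run on made-up synthetic data since the real COVID data wasn't shared. One extra step is assumed here that isn't in the answer above: shifting the date ordinals so the first day is 0. Raw ordinals are around 737,000, and `np.exp` overflows float64 once its argument exceeds roughly 709.78, so `np.exp(-b * x)` on raw ordinals blows up for almost any useful growth rate. The starting values in `p0` are likewise just illustrative guesses.

```python
import datetime as dt

import numpy as np
from scipy.optimize import curve_fit


def func(x, a, b, c):
    # Same model form as in the answer; for growing data, b comes out negative
    return a * np.exp(-b * x) + c


# Synthetic, made-up data (NOT the asker's data): 60 consecutive days as ordinals
start = dt.date(2020, 3, 1).toordinal()
x_ordinal = start + np.arange(60)

# Shift so day 0 is the first date; keeps the exponent in a sane range
t = (x_ordinal - x_ordinal.min()).astype(float)

rng = np.random.default_rng(0)
y = 50.0 * np.exp(0.08 * t) + 10.0 + rng.normal(0.0, 5.0, size=t.size)

# Rough initial guess: a near the first value, small negative b (growth), c near 0
p0 = [y[0], -0.05, 0.0]
popt, pcov = curve_fit(func, t, y, p0=p0, maxfev=10000)
a, b, c = popt
fit_y = func(t, a, b, c)
print(f"a={a:.2f}, b={b:.4f}, c={c:.2f}")
```

To keep the date tick labels from the question working, fit on the shifted `t` but plot against the original ordinals, e.g. `plt.plot(x_ordinal, fit_y)`.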