regplot logistic and statsmodels logit. Why different results?
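In this code, why are the coefficients (intercept and x) different between the logistic seaborn regplot visualization and the statsmodels logit() analysis? Shouldn't the two lines at least start from the same intercept? What am I doing wrong?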
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.formula.api import logit
np.random.seed(2022) # to get the same data each time
df = pd.DataFrame({
    'y': np.random.randint(2, size=10),
    'x': np.random.rand(10)
})
mdl = logit("y ~ x", data=df).fit()
print(mdl.summary())
sns.regplot(y='y', x='x', data=df, logistic=True, ci=None)
plt.axline(xy1=(0, mdl.params['Intercept']), slope=mdl.params['x'], color='black')
plt.show()
Optimization terminated successfully.
         Current function value: 0.665054
         Iterations 5
                           Logit Regression Results
==============================================================================
Dep. Variable:                     y   No. Observations:                    10
Model:                         Logit   Df Residuals:                         8
Method:                          MLE   Df Model:                             1
Date:               Tue, 26 Jul 2022   Pseudo R-squ.:                  0.04053
Time:                       07:43:10   Log-Likelihood:                 -6.6505
converged:                      True   LL-Null:                        -6.9315
Covariance Type:           nonrobust   LLR p-value:                     0.4535
==============================================================================
                 coef    std err          z      P>|z|      [0.025      0.975]
------------------------------------------------------------------------------
Intercept      2.0253      2.902      0.698      0.485      -3.663       7.713
x             -2.7006      3.741     -0.722      0.470     -10.033       4.632
==============================================================================
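What you see in the sns.regplot() plot is the predicted probability, not the logit (the straight line in log-odds defined by the estimated intercept and slope). The coefficients themselves are not different: with logistic=True, seaborn fits a logistic regression with statsmodels and draws the fitted curve on the probability scale, while plt.axline draws the untransformed linear predictor. So to make the logit model's results match the plot, you have to use the intercept and slope to compute the probability at each x value.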
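To see why the two lines cannot share a starting point, it helps to write the fitted model out with the coefficients from the summary above. The straight line drawn by plt.axline is the log-odds, so at x = 0 it sits at 2.0253, outside the [0, 1] range of a probability; the curve on the probability scale passes through the sigmoid of that value instead. The numbers below are rounded by hand from the summary coefficients.

```latex
\operatorname{logit}(p) = \ln\frac{p}{1-p} = \beta_0 + \beta_1 x,
\qquad
p = \sigma(\beta_0 + \beta_1 x) = \frac{1}{1 + e^{-(\beta_0 + \beta_1 x)}}

\hat\beta_0 = 2.0253,\quad \hat\beta_1 = -2.7006:
\qquad \hat p(0) = \sigma(2.0253) \approx 0.883,
\qquad \hat p(1) = \sigma(-0.6753) \approx 0.337
```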
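The probabilities are obtained by first computing the logits (the linear combination of the estimated intercept and slope with the x values):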
logits = mdl.params['Intercept'] + mdl.params['x'] * df['x']
and then passing them through the sigmoid (logistic) function:
probs = np.exp(logits) / (1 + np.exp(logits))
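As a cross-check that is not part of the original answer, the manual computation should agree with statsmodels' own predict(), which for a fitted Logit model returns probabilities. This is a minimal sketch assuming the same seed and data as the question; the names probs_manual and probs_predict are illustrative, and 1 / (1 + np.exp(-logits)) is an algebraically equivalent form of the sigmoid written above.

```python
import numpy as np
import pandas as pd
from statsmodels.formula.api import logit

# Same toy data as in the question
np.random.seed(2022)
df = pd.DataFrame({
    'y': np.random.randint(2, size=10),
    'x': np.random.rand(10)
})
mdl = logit("y ~ x", data=df).fit(disp=0)  # disp=0 silences the convergence printout

# Manual route: linear predictor (log-odds), then the sigmoid
logits = mdl.params['Intercept'] + mdl.params['x'] * df['x']
probs_manual = 1 / (1 + np.exp(-logits))

# Built-in route: predict() on a Logit result returns predicted probabilities
probs_predict = mdl.predict(df)

print(np.allclose(probs_manual, probs_predict))
```

In practice, mdl.predict() is the less error-prone option, because it applies the formula's design matrix and the inverse link for you.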
Here is the complete code with both lines:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.formula.api import logit
np.random.seed(2022) # to get the same data each time
df = pd.DataFrame({
    'y': np.random.randint(2, size=10),
    'x': np.random.rand(10)
})
mdl = logit("y ~ x", data=df).fit()
print(mdl.summary())
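# Sort by x so plt.plot draws the curve left to right instead of joining points in random order
df = df.sort_values('x')
# Log-odds from the fitted intercept and slope, then the sigmoid to get probabilities
logits = mdl.params['Intercept'] + mdl.params['x'] * df['x']
probs = np.exp(logits) / (1 + np.exp(logits))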
sns.regplot(y='y', x='x', data=df, logistic=True, ci=None)
plt.plot(df['x'], probs, color='red')
plt.show()