簡體   English   中英

檢查 scikit-learn 管道中的特征重要性

[英]Inspection of the feature importance in scikit-learn pipelines

我使用 scikit-learn 定義了以下管道:

model_lg = Pipeline([("preprocessing", StandardScaler()), ("classifier", LogisticRegression())])
model_dt = Pipeline([("preprocessing", StandardScaler()), ("classifier", DecisionTreeClassifier())])
model_gb = Pipeline([("preprocessing", StandardScaler()), ("classifier", HistGradientBoostingClassifier())])

然后我使用交叉驗證來評估每個 model 的性能:

cv_results_lg = cross_validate(model_lg, data, target, cv=5, return_train_score=True, return_estimator=True)
cv_results_dt = cross_validate(model_dt, data, target, cv=5, return_train_score=True, return_estimator=True)
cv_results_gb = cross_validate(model_gb, data, target, cv=5, return_train_score=True, return_estimator=True)

當我嘗試使用coef_方法檢查每個 model 的特征重要性時,它給我一個歸因錯誤:

model_lg.steps[1][1].coef_
AttributeError: 'LogisticRegression' object has no attribute 'coef_'

model_dt.steps[1][1].coef_
AttributeError: 'DecisionTreeClassifier' object has no attribute 'coef_'

model_gb.steps[1][1].coef_
AttributeError: 'HistGradientBoostingClassifier' object has no attribute 'coef_'

我想知道,我該如何解決這個錯誤? 或者是否有任何其他方法來檢查每個 model 中的特征重要性?

Imo,這里的要點如下。 一方面,管道實例model_lgmodel_dt等未明確安裝(您沒有直接在它們上調用方法.fit() ),這會阻止您嘗試訪問實例本身的coef_屬性。

另一方面,通過使用參數return_estimator=True調用.cross_validate() (僅在交叉驗證方法中使用.cross_validate()是可能的),您可以為每個 cv 拆分返回擬合估計量,但您應該訪問他們通過你的字典cv_results_lgcv_results_dt等(在'estimator'鍵上)。 這是代碼中的參考,這是一個示例:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_validate

X, y = load_iris(return_X_y=True)

model_lg = Pipeline([("preprocessing", StandardScaler()), ("classifier", LogisticRegression())])

cv_results_lg = cross_validate(model_lg, X, y, cv=5, return_train_score=True, return_estimator=True)

這些將是 - 例如 - 在第一次折疊時計算的結果。

cv_results_lg['estimator'][0].named_steps['classifier'].coef_

有關相關主題的有用見解可以在以下位置找到:

在某些算法和打印精度中進行循環

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM