
How to plot feature importance for a random forest in Python

I have created a random forest model and would like to plot its feature importances:

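from sklearn.ensemble import RandomForestClassifier
# max_features="sqrt" is what "auto" meant for classifiers; "auto" was removed in scikit-learn 1.3
model_RF_tune = RandomForestClassifier(random_state=0, n_estimators=80,
                                       min_samples_split=10, max_depth=None, max_features="sqrt")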

I have tried defining a function:

def plot_feature_importances_health(model):
    n_features = model.data.shape
    plt.barh(range(n_features), model.feature_importances_, align = "center")
    plt.yticks(np.arrange(n_features), df_health_reconstructed.feature_names)
    plt.xlabel("Feature importance")
    plt.ylabel("Feature")
    plt.ylim(-1, n_features)

but calling plot_feature_importances_health(model_RF_tune)

gives this error: AttributeError: 'RandomForestClassifier' object has no attribute 'data'

How do I plot it correctly?
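The function fails at model.data because a fitted scikit-learn estimator has no data attribute; .data and .feature_names belong to dataset objects such as the one returned by load_breast_cancer(). Two further problems would surface next: .shape is a tuple rather than an integer, and np.arrange should be np.arange. A minimal corrected sketch, which takes the feature names as an argument instead of reading them from df_health_reconstructed, might look like this. The breast-cancer dataset and the name plot_feature_importances are illustrative stand-ins for the question's health data and function.

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier


def plot_feature_importances(model, feature_names):
    """Bar chart of a fitted forest's impurity-based feature importances."""
    importances = model.feature_importances_
    n_features = len(importances)  # an int; X.shape would be a (rows, cols) tuple
    plt.figure(figsize=(8, 0.3 * n_features + 1))  # fresh figure, taller for many features
    plt.barh(np.arange(n_features), importances, align="center")
    plt.yticks(np.arange(n_features), feature_names)
    plt.xlabel("Feature importance")
    plt.ylabel("Feature")
    plt.ylim(-1, n_features)
    return plt.gca()


# Built-in dataset used here only as a stand-in for the question's health data
data = load_breast_cancer()
model = RandomForestClassifier(random_state=0, n_estimators=80,
                               min_samples_split=10, max_features="sqrt")
model.fit(data.data, data.target)
ax = plot_feature_importances(model, data.feature_names)
```

With your own data, pass the model and the column names of the DataFrame you fitted it on, e.g. plot_feature_importances(model_RF_tune, X.columns).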

scikit-learn estimators such as RandomForestClassifier have no data attribute, so model.data fails. Would you like to try my code instead? Note that it plots only the top 10 features (change n to show more or fewer).

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

# use RandomForestClassifier to look for important key features
# X must be a pandas DataFrame (its column names become the bar labels); y holds the targets
SEED = 0  # any fixed random seed
n = 10    # choose top n features
rfc = RandomForestClassifier(random_state=SEED, n_estimators=200, max_depth=3)
rfc_model = rfc.fit(X, y)

top = pd.Series(rfc_model.feature_importances_, index=X.columns).nlargest(n)
ax = top.plot(kind='barh', figsize=[8, n/2.5], color='navy')
ax.invert_yaxis()    # most important feature is on top, i.e. descending order

ticks_x = np.linspace(0, 0.5, 6)   # (start, end, number of ticks); widen if an importance exceeds 0.5
plt.xticks(ticks_x, fontsize=15, color='black')
plt.yticks(size=15, color='navy')
plt.title('Top Features derived by RandomForestClassifier', family='fantasy', size=15)
print(list(top.index))
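To try the top-n approach without your own data, here is a self-contained sketch. The synthetic dataset from make_classification, the feat_{i} column names, and the helper top_n_importances are assumptions made for illustration; they are not part of the answer above. Passing an explicit ax keeps the plot from being drawn onto some earlier figure.

```python
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

SEED = 0
n = 10  # number of top features to show

# Synthetic stand-in data: 25 features, only 5 of them informative
X_arr, y = make_classification(n_samples=500, n_features=25, n_informative=5,
                               random_state=SEED)
X = pd.DataFrame(X_arr, columns=[f"feat_{i}" for i in range(X_arr.shape[1])])


def top_n_importances(model, columns, n):
    """Hypothetical helper: the n largest impurity-based importances, labelled by column."""
    return pd.Series(model.feature_importances_, index=columns).nlargest(n)


rfc_model = RandomForestClassifier(random_state=SEED, n_estimators=200, max_depth=3).fit(X, y)
top = top_n_importances(rfc_model, X.columns, n)

fig, ax = plt.subplots(figsize=(8, n / 2.5))
top.plot(kind="barh", color="navy", ax=ax)
ax.invert_yaxis()  # largest importance at the top
ax.set_xlabel("Importance")
ax.set_title("Top features derived by RandomForestClassifier")
print(list(top.index))
```

Because nlargest already returns the values in descending order, invert_yaxis is all that is needed to put the strongest feature on top of a horizontal bar chart.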

This one seems to work for me:

# Jupyter/IPython only; omit the next line in a plain .py script
%matplotlib inline
import pandas as pd

# dataframe_name is the feature DataFrame (X) and model_name is the fitted scikit-learn model
feats = {}  # a dict to hold feature_name: feature_importance
for feature, importance in zip(dataframe_name.columns, model_name.feature_importances_):
    feats[feature] = importance  # add the name/value pair

importances = pd.DataFrame.from_dict(feats, orient='index').rename(columns={0: 'Gini-importance'})
importances.sort_values(by='Gini-importance').plot(kind='barh', color="SeaGreen", figsize=(10, 8))

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite the site URL or the original address. For any questions, contact: yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM