
Cross validation dataset folds for Random Forest feature importance

I am trying to generate a random forest feature importance plot using cross-validation folds. When only the feature (X) and target (y) data are used, the implementation is straightforward:

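import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier

# X: feature DataFrame, y: target labels
rfc = RandomForestClassifier()
rfc.fit(X, y)
# pair each column name with its (rounded) importance score
importances = pd.DataFrame({'FEATURE': X.columns, 'IMPORTANCE': np.round(rfc.feature_importances_, 3)})
importances = importances.sort_values('IMPORTANCE', ascending=False).set_index('FEATURE')

print(importances)
importances.plot.bar()
plt.show()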

which yields: [image: bar chart of the feature importances, not reproduced]

However, how could I transform this code to create a similar plot for each of the cross-validation (k-fold) folds that I create?

The code that I have at the moment is:

# Empty list storage to collect all results for displaying as plots
mylist = []

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
kf = KFold(n_splits=3)
for train, test in kf.split(X, y):
    train_data = np.array(X)[train]
    test_data = np.array(y)[test]
for rfc = RandomForestClassifier():
    rfc.fit(train_data, test_data)

For example, the code above creates 3 folds using cross-validation, and my aim is to create a feature importance plot for each of the 3 folds, i.e. 3 plots in total. At the moment, it gives me errors in the loop.

I am not sure what the most efficient technique is for using each of the k folds to generate a random forest feature importance graph for that fold.

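One cause of the error is rfc.fit(train_data, test_data). The second argument must be the training labels (y[train]), not the test labels; test_data also has a different number of rows than train_data, so the fit fails. In addition, "for rfc = RandomForestClassifier():" is not valid Python: create the model with a plain assignment, and do the fitting (and plotting) inside the KFold loop so that it runs once per fold.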
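A minimal corrected sketch of the question's loop, assuming X and y are the feature matrix and labels (NumPy arrays or pandas objects). The helper name fold_importances, the random_state=0 setting and the dummy dataset are hypothetical, added only for illustration. A fresh forest is fitted on each training split, and each fold's importances are stored in mylist so they can be plotted afterwards, one figure per fold.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold


def fold_importances(X, y, n_splits=3, random_state=0):
    """Hypothetical helper: fit one forest per KFold training split, return importances."""
    # np.asarray lets X and y be either NumPy arrays or pandas objects
    X = np.asarray(X)
    y = np.asarray(y)
    mylist = []  # one importance vector per fold, like the question's mylist
    kf = KFold(n_splits=n_splits)
    for train, _ in kf.split(X, y):
        rfc = RandomForestClassifier(random_state=random_state)
        # X[train] and y[train] select the same rows, so the sample counts match
        rfc.fit(X[train], y[train])
        mylist.append(rfc.feature_importances_)
    return mylist


# dummy data standing in for the real X and y
X, y = make_classification(n_features=10, random_state=0)
importances_per_fold = fold_importances(X, y)
for fold, imp in enumerate(importances_per_fold, start=1):
    print('Fold {}: {}'.format(fold, np.round(imp, 3)))
```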

As for the plotting, you can try something like the code below. Note that in this case k-fold CV is only used to select different sets of training data; the test folds are ignored because no predictions are made:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
from sklearn.datasets import make_classification

# dummy classification dataset
X, y = make_classification(n_features=10)
# dummy feature names
feature_names = ['F{}'.format(i) for i in range(X.shape[1])]

kf = KFold(n_splits=3)
rfc = RandomForestClassifier()
count = 1
# test data is not needed for fitting
for train, _ in kf.split(X, y):
    rfc.fit(X[train, :], y[train])
    # sort the feature index by importance score in descending order
    importances_index_desc = np.argsort(rfc.feature_importances_)[::-1]
    feature_labels = [feature_names[i] for i in importances_index_desc]
    # plot
    plt.figure()
    plt.bar(feature_labels, rfc.feature_importances_[importances_index_desc])
    plt.xticks(feature_labels, rotation='vertical')
    plt.ylabel('Importance')
    plt.xlabel('Features')
    plt.title('Fold {}'.format(count))
    count = count + 1
plt.show()
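If the goal is to compare the folds rather than inspect them one by one, one option (a different approach from the answer's loop) is scikit-learn's cross_validate with return_estimator=True. It fits one forest per fold and returns the fitted estimators, so their feature_importances_ can be collected into one DataFrame and summarised as a mean bar chart with standard-deviation error bars. This sketch uses dummy data; random_state=0 is assumed only for reproducibility. Unlike the loop above, it also uses the held-out folds, to report each fold's test accuracy.

```python
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold, cross_validate

# dummy data and feature names
X, y = make_classification(n_features=10, random_state=0)
feature_names = ['F{}'.format(i) for i in range(X.shape[1])]

# Fits one forest per fold; return_estimator=True keeps each fitted model
cv_results = cross_validate(
    RandomForestClassifier(random_state=0),
    X,
    y,
    cv=KFold(n_splits=3),
    return_estimator=True,
)
estimators = cv_results['estimator']

# One row per fold, one column per feature
fold_importances = pd.DataFrame(
    [est.feature_importances_ for est in estimators],
    columns=feature_names,
    index=['Fold {}'.format(i) for i in range(1, len(estimators) + 1)],
)

# Mean and standard deviation of each feature's importance across folds, highest mean first
summary = fold_importances.agg(['mean', 'std']).T.sort_values('mean', ascending=False)
print(summary.round(3))
# Accuracy of each fold's model on its own held-out test fold
print(cv_results['test_score'])

plt.figure()
plt.bar(summary.index, summary['mean'], yerr=summary['std'])
plt.xticks(rotation='vertical')
plt.ylabel('Mean importance (error bar: std across folds)')
plt.xlabel('Features')
plt.tight_layout()
plt.show()
```

A large error bar on a feature indicates that its importance changes noticeably depending on which training split the forest saw.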

This is the code that worked for me:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold
from sklearn.datasets import make_classification

# classification dataset (NumPy arrays)
data_x, data_y = make_classification(n_features=9)

# feature names must be defined before the loop; for a pandas DataFrame use
# list(data_x.columns) and select rows with data_x.iloc[train] instead
feature_names = ['F{}'.format(i) for i in range(data_x.shape[1])]

kf = KFold(n_splits=10)
rfc = RandomForestClassifier()
count = 1
# test data is not needed for fitting
for train, _ in kf.split(data_x, data_y):
    rfc.fit(data_x[train, :], data_y[train])
    # sort the feature index by importance score in descending order
    importances_index_desc = np.argsort(rfc.feature_importances_)[::-1]
    # feature_names[i] is the name of column i, matching feature_importances_[i]
    feature_labels = [feature_names[i] for i in importances_index_desc]

    # plot
    plt.figure()
    plt.bar(feature_labels, rfc.feature_importances_[importances_index_desc])
    plt.xticks(feature_labels, rotation='vertical')
    plt.ylabel('Importance')
    plt.xlabel('Features')
    plt.title('Fold {}'.format(count))
    count = count + 1
plt.show()

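The question's first snippet uses data_x.columns, which suggests the real data is a pandas DataFrame. Here is a sketch of the same per-fold loop for that case, assuming data_x is a DataFrame and data_y a Series. Both are built from dummy data here, and the column names feat_0, feat_1, etc. are hypothetical. Storing each fold's scores in a Series indexed by column name keeps every bar labelled with the right feature.

```python
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import KFold

# Hypothetical DataFrame/Series standing in for the real data_x and data_y
X_arr, y_arr = make_classification(n_features=9, random_state=0)
data_x = pd.DataFrame(X_arr, columns=['feat_{}'.format(i) for i in range(X_arr.shape[1])])
data_y = pd.Series(y_arr, name='target')

feature_names = list(data_x.columns)  # defined once, before the loop
kf = KFold(n_splits=10)
per_fold = []
for fold, (train, _) in enumerate(kf.split(data_x, data_y), start=1):
    rfc = RandomForestClassifier(random_state=0)
    # KFold yields row positions, so select rows with .iloc
    rfc.fit(data_x.iloc[train], data_y.iloc[train])
    # Indexing by column name keeps each score attached to its feature
    imp = pd.Series(rfc.feature_importances_, index=feature_names, name='Fold {}'.format(fold))
    per_fold.append(imp.sort_values(ascending=False))

# one bar chart per fold
for imp in per_fold:
    fig, ax = plt.subplots()
    imp.plot.bar(ax=ax, title=imp.name)
    ax.set_ylabel('Importance')
    ax.set_xlabel('Features')
    fig.tight_layout()
plt.show()
```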

The technical posts on this site are licensed under CC BY-SA 4.0. If you reprint them, please cite this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

 
© 2020-2024 STACKOOM.COM