
How can we calculate the accuracy of a random forest classifier for a 4-label classification problem?

I am trying to predict the quality attributes of products sold over the last decade. Based on likes/dislikes, I have assigned one of 4 labels to each product. The labels are: bad, good, very good, very bad.

I have downloaded the last decade of data and categorized the samples into these 4 labels. When I feed the input to a random forest classifier, it gives valid results and reports the feature importances.

Here is the code:

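from sklearn.ensemble import RandomForestClassifier

classifier = RandomForestClassifier(
    n_estimators=100, n_jobs=6, oob_score=True, random_state=50,
    # "auto" meant "sqrt" for classifiers; it was removed in scikit-learn 1.3
    max_features="sqrt", min_samples_leaf=50
)

# Alternative configuration (disabled):
# classifier = RandomForestClassifier(
#     n_estimators=100, n_jobs=6, oob_score=True, random_state=50  # , max_depth=3
# )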

I just want to understand how we can calculate the accuracy of the model, given that it has 4 labels.

There are a few metrics you can check to assess model quality. The first is the overall accuracy (the fraction of samples it classified correctly). For this you can simply use sklearn's accuracy_score, which works unchanged for multiclass labels:

from sklearn.metrics import accuracy_score
accuracy_score(y_true, y_pred)
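As a minimal illustration, assuming a handful of made-up labels (the lists below are hypothetical, not the asker's data), accuracy_score accepts the four string labels directly and returns the same number as counting matches by hand:

```python
from sklearn.metrics import accuracy_score

# Hypothetical true and predicted labels for 8 products (illustrative data only)
y_true = ["bad", "good", "very good", "very bad", "good", "bad", "very good", "good"]
y_pred = ["bad", "good", "good", "very bad", "good", "very bad", "very good", "bad"]

# For multiclass targets, accuracy is the fraction of samples whose
# predicted label equals the true label
acc = accuracy_score(y_true, y_pred)
print(acc)  # 0.625 (5 of 8 correct)

# The same value computed by hand
manual = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(manual)  # 0.625
```

No label encoding is needed; the number of classes does not change how accuracy is computed.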

Of course, this does not tell you which class is being misclassified, or as what (e.g., it might be more acceptable to classify very good as good rather than as bad). For this you need a confusion matrix:

from sklearn.metrics import confusion_matrix
confusion_matrix(y_true, y_pred)
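A small sketch, again with hypothetical labels: passing the `labels` argument fixes the row and column order, and listing the grades from worst to best makes "nearly right" mistakes show up just off the diagonal. The off-by-one count at the end is an illustrative extra, not a built-in sklearn metric.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Hypothetical labels for 8 products (illustrative data only)
y_true = ["bad", "good", "very good", "very bad", "good", "bad", "very good", "good"]
y_pred = ["bad", "good", "good", "very bad", "good", "very bad", "very good", "bad"]

# Order the grades from worst to best so neighbouring rows/columns are neighbouring grades
order = ["very bad", "bad", "good", "very good"]
cm = confusion_matrix(y_true, y_pred, labels=order)
print(cm)  # rows = true class, columns = predicted class, both in `order`

# Count mistakes that are only one grade away (cells directly beside the diagonal)
i, j = np.indices(cm.shape)
off_by_one = int(cm[np.abs(i - j) == 1].sum())
total_errors = int(cm.sum() - np.trace(cm))
print(off_by_one, total_errors)  # 3 3
```

In this toy example all three errors are one grade away from the true label, which the overall accuracy alone would not reveal.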

You probably also want to look into per-class precision and recall, as they help you interpret the matrix and quantify it. Since your labels are ranked (very bad < bad < good < very good), you can also convert them to integer values and address the problem as a regression instead of a classification (then round the outputs back to integers). This way the model takes the order of the labels into account, which gives you a simple form of ordinal classification.
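One way to sketch the regression-then-round idea, assuming synthetic data in place of the real product features (the latent-score generator below is hypothetical): map the grades to 0..3, fit a RandomForestRegressor, round and clip its predictions, map them back to labels, and then evaluate with accuracy plus per-class precision and recall from classification_report. This is a simple approximation, not a dedicated ordinal model.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

# Assumed ordering of the four grades, worst to best
order = ["very bad", "bad", "good", "very good"]
to_int = {label: i for i, label in enumerate(order)}

# Hypothetical stand-in data: a noisy latent score cut into 4 ordered grades
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
latent = X @ np.array([1.0, 0.5, -0.5, 0.0, 0.2]) + rng.normal(scale=0.5, size=1000)
y_int = np.digitize(latent, np.quantile(latent, [0.25, 0.5, 0.75]))  # values 0..3
y = np.array(order)[y_int]  # string labels, as in the original problem

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.25, random_state=0)

# Fit a regressor on the integer-encoded grades instead of a classifier
reg = RandomForestRegressor(n_estimators=100, random_state=50)
reg.fit(X_train, np.array([to_int[label] for label in y_train]))

# Round continuous predictions to the nearest grade, keep them in 0..3,
# and map them back to the string labels
pred_int = np.clip(np.rint(reg.predict(X_val)), 0, len(order) - 1).astype(int)
y_pred = np.array(order)[pred_int]

acc = accuracy_score(y_val, y_pred)
print(acc)
# Per-class precision and recall, listed in the ordinal order
print(classification_report(y_val, y_pred, labels=order))
```

The same classification_report call works on the output of the plain RandomForestClassifier, so both approaches can be compared with identical metrics.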

EDIT:

Just in case the answer is not clear, you get y_pred in the following way:

classifier.fit(x_train, y_train)
y_pred = classifier.predict(x_val)
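Put together, a minimal end-to-end sketch might look like the following. It assumes hypothetical synthetic data and a held-out validation set (`y_val` is not defined in the answer above and is an assumption here). Because the classifier in the question already sets `oob_score=True`, the fitted model also exposes `oob_score_`, an out-of-bag accuracy estimate computed on the training data.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split

# Hypothetical stand-in data: 1000 products, 5 features, 4 string labels
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 5))
labels = np.array(["very bad", "bad", "good", "very good"])
y = labels[np.digitize(X[:, 0] + 0.5 * X[:, 1], [-1.0, 0.0, 1.0])]

# Hold out a validation set the model never sees during fitting
x_train, x_val, y_train, y_val = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)

# Same settings as in the question, with max_features="sqrt" instead of "auto"
classifier = RandomForestClassifier(
    n_estimators=100, n_jobs=6, oob_score=True, random_state=50,
    max_features="sqrt", min_samples_leaf=50
)
classifier.fit(x_train, y_train)
y_pred = classifier.predict(x_val)

# Accuracy on the held-out validation set
val_acc = accuracy_score(y_val, y_pred)
# Out-of-bag accuracy: each training sample is scored only by trees that did not see it
oob_acc = classifier.oob_score_
print(f"validation accuracy: {val_acc:.3f}, OOB accuracy: {oob_acc:.3f}")
print(confusion_matrix(y_val, y_pred, labels=labels))
```

The two accuracies should usually be close; a large gap is worth investigating before trusting either number.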
