
How to compute precision-recall for a decision tree in sklearn?

I am trying to make predictions on the standard dataset "iris.csv":

import pandas as pd
from sklearn import tree
df = pd.read_csv('iris.csv')
df.columns = ['X1', 'X2', 'X3', 'X4', 'Y']
df.head()

# Decision tree
from sklearn.model_selection import train_test_split
decision = tree.DecisionTreeClassifier(criterion='gini')
X = df.values[:, 0:4]
Y = df.values[:, 4]
trainX, testX, trainY, testY = train_test_split(X, Y, test_size=0.25)
decision.fit(trainX, trainY)
y_score = decision.score(testX, testY)
print('Accuracy: ', y_score)


# Compute the average precision score
from sklearn.metrics import average_precision_score
average_precision = average_precision_score(testY, y_score)

print('Average precision-recall score: {0:0.2f}'.format(
      average_precision))

And I get a ValueError:

File "C:/Users/Ultra/PycharmProjects/poker_ML/decision_tree.py", line 20, in <module>
    average_precision = average_precision_score(testY, y_score)
  File "C:\Users\Ultra\PycharmProjects\poker_ML\venv\lib\site-packages\sklearn\metrics\ranking.py", line 241, in average_precision_score
    average, sample_weight=sample_weight)
  File "C:\Users\Ultra\PycharmProjects\poker_ML\venv\lib\site-packages\sklearn\metrics\base.py", line 74, in _average_binary_score
    raise ValueError("{0} format is not supported".format(y_type))
ValueError: multiclass format is not supported

How can I compute precision-recall for 3 classes? How does precision-recall work for a decision tree in sklearn? Maybe I made a mistake in calculating "y_score"?

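According to the scikit-learn docs for the version that raised this error, average_precision_score cannot handle multiclass targets, which is exactly what the ValueError says. There is also a second problem: decision.score(testX, testY) returns the mean accuracy as a single float, whereas average_precision_score expects one score per sample (for example, class probabilities from predict_proba).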
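If you specifically need average precision, one common workaround is one-vs-rest: binarize the 3-class labels into a 0/1 indicator matrix and pass per-class probabilities from predict_proba as the scores. The sketch below assumes load_iris() as a stand-in for your iris.csv, and the random_state and stratify arguments are added only to make the run repeatable; treat it as an illustration, not the only way to do this.

```python
# Sketch: one-vs-rest average precision for a 3-class decision tree.
# Assumption: load_iris() replaces the question's iris.csv; random_state/stratify
# are added only for repeatability.
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
trainX, testX, trainY, testY = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y)

clf = DecisionTreeClassifier(criterion='gini', random_state=0)
clf.fit(trainX, trainY)

# One score per sample and per class (n_samples x n_classes),
# unlike clf.score(), which returns a single accuracy value.
y_score = clf.predict_proba(testX)

# 3-class labels -> 3-column 0/1 indicator matrix, columns ordered like clf.classes_.
y_true_bin = label_binarize(testY, classes=clf.classes_)

ap_micro = average_precision_score(y_true_bin, y_score, average='micro')
ap_macro = average_precision_score(y_true_bin, y_score, average='macro')
ap_per_class = average_precision_score(y_true_bin, y_score, average=None)

print('Micro-averaged AP: {0:0.2f}'.format(ap_micro))
print('Macro-averaged AP: {0:0.2f}'.format(ap_macro))
print('Per-class AP:', ap_per_class)
```

Passing classes=clf.classes_ keeps the indicator columns aligned with the predict_proba columns; this also works when the labels are strings, as they are when read from the CSV. Note that a fully grown tree tends to output probabilities of 0 or 1, so the resulting precision-recall curve has very few thresholds.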

Instead, you may use precision_score like this:

# Decision tree
...
y_pred = decision.predict(testX)          # predicted class label for each test sample
y_score = decision.score(testX, testY)    # mean accuracy: one float, not per-sample scores
print('Accuracy: ', y_score)

# Compute precision scores; the signature is precision_score(y_true, y_pred, ...)
from sklearn.metrics import precision_score

# 'micro': pool true/false positives over all classes, then compute one precision
micro_precision = precision_score(testY, y_pred, average='micro')
print('Micro-averaged precision score: {0:0.2f}'.format(
      micro_precision))

# 'macro': unweighted mean of the per-class precisions
macro_precision = precision_score(testY, y_pred, average='macro')
print('Macro-averaged precision score: {0:0.2f}'.format(
      macro_precision))

# None: one precision value per class
per_class_precision = precision_score(testY, y_pred, average=None)
print('Per-class precision score:', per_class_precision)
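Since the question asks for precision and recall, it may help to get both at once. A minimal sketch, assuming load_iris() in place of iris.csv and a fixed random_state/stratified split (both added here for repeatability): precision_recall_fscore_support returns per-class arrays, and classification_report prints them as a table together with the averaged rows. In both functions the true labels come first and the predictions second.

```python
# Sketch: per-class precision and recall for a multiclass decision tree.
# Assumption: load_iris() replaces iris.csv; random_state/stratify added for repeatability.
from sklearn.datasets import load_iris
from sklearn.metrics import classification_report, precision_recall_fscore_support
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
X, y = iris.data, iris.target
trainX, testX, trainY, testY = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y)

decision = DecisionTreeClassifier(criterion='gini', random_state=0)
decision.fit(trainX, trainY)
y_pred = decision.predict(testX)

# y_true first, y_pred second; average=None gives one value per class.
precision, recall, f1, support = precision_recall_fscore_support(
    testY, y_pred, average=None)
print('Per-class precision:', precision)
print('Per-class recall:   ', recall)

# Text table with per-class rows plus macro and weighted averages.
report = classification_report(testY, y_pred, target_names=iris.target_names)
print(report)
```

Here support is the number of true samples of each class in the test set, which is useful when judging whether the averages hide an imbalanced class.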

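Note that you need to specify how to average the scores: 'micro' pools the true and false positives of all classes before computing precision, 'macro' takes the unweighted mean of the per-class precisions, and None returns one value per class. This is especially relevant if your dataset shows label imbalance (which iris does not).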
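To make the difference between the averaging modes concrete, the sketch below computes micro and macro precision by hand from a confusion matrix and checks them against precision_score. The labels are made-up toy data for illustration, not results from the iris model.

```python
# Sketch: 'micro' vs 'macro' precision computed by hand from a confusion matrix.
# The labels below are hypothetical toy data, not iris results.
import numpy as np
from sklearn.metrics import confusion_matrix, precision_score

y_true = np.array([0, 0, 0, 0, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 0, 1, 2, 1, 1, 2, 2, 0, 2])

cm = confusion_matrix(y_true, y_pred)   # rows = true class, columns = predicted class
tp = np.diag(cm)                        # correct predictions per class
predicted_per_class = cm.sum(axis=0)    # times each class was predicted (TP + FP)

# Macro: precision of each class, then an unweighted mean.
per_class = tp / predicted_per_class
macro = per_class.mean()

# Micro: pool TP and FP over all classes first, then divide once.
micro = tp.sum() / predicted_per_class.sum()

print('By hand:  ', per_class, macro, micro)
print('sklearn:  ', precision_score(y_true, y_pred, average=None),
      precision_score(y_true, y_pred, average='macro'),
      precision_score(y_true, y_pred, average='micro'))
```

With these toy labels the per-class precisions are 2/3, 2/3 and 3/4, so macro precision is about 0.69 while micro precision is 7/10 = 0.70. That micro value is also the accuracy: in single-label multiclass data every prediction is a true or false positive of exactly one class, so micro-averaged precision always equals accuracy.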
