簡體   English   中英

在 scikit-learn 中使用 DecisionTreeClassifier 修復 100% 准確率

[英]Fixing the 100% accuracy with DecisionTreeClassifier in scikit-learn

我正在嘗試使用決策樹進行分類並獲得 100% 的准確率。

這是一個常見問題, 在此處此處進行了描述。 在許多其他問題上。

數據在這里

兩個最佳猜測:

  • 我錯誤地拆分數據
  • 我的數據集太不平衡了

我的代碼有什么問題?

import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import cross_val_score
import sklearn.model_selection as cv
from sklearn.metrics import mean_squared_error as MSE
from sklearn.model_selection import train_test_split 
from sklearn import metrics
from sklearn.metrics import confusion_matrix 
from sklearn.metrics import accuracy_score 

# Split data
Y = starbucks.iloc[:, 4]
X = starbucks.loc[:, starbucks.columns != 'offer_completed']

# Splitting the dataset into train and test 
X_train, X_test, y_train, y_test = train_test_split(X, Y, 
                                                    test_size=0.3,
                                                    random_state=100) 

# Creating the classifier object 
clf_gini = DecisionTreeClassifier(criterion = "gini", 
                                  random_state = 100, 
                                  max_depth = 3, 
                                  min_samples_leaf = 5) 

# Performing training 
clf_gini.fit(X_train, y_train)

# Predicton on test with giniIndex 
y_pred = clf_gini.predict(X_test) 
print("Predicted values:") 
print(y_pred) 

print("Confusion Matrix: ", confusion_matrix(y_test, y_pred)) 

print ("Accuracy : ", accuracy_score(y_test, y_pred)*100) 

print("Report : ", classification_report(y_test, y_pred)) 

y_pred_gini = prediction(X_test, clf_gini) 
cal_accuracy(y_test, y_pred_gini) 


Predicted values:
[0. 0. 0. ... 0. 0. 0.]
Confusion Matrix:  [[36095     0]
                    [    0  8158]]
Accuracy :  100.0

當我打印 X 時,它顯示offer_completed已被刪除。

X.dtypes

offer_received               int64
offer_viewed               float64
time_viewed_received       float64
time_completed_received    float64
time_completed_viewed      float64
transaction                float64
amount                     float64
total_reward               float64
age                        float64
income                     float64
male                         int64
membership_days            float64
reward_each_time           float64
difficulty                 float64
duration                   float64
email                      float64
mobile                     float64
social                     float64
web                        float64
bogo                       float64
discount                   float64
informational              float64

安裝 model 並檢查特征重要性,您可以看到除了total_reward之外它們全為零。 然后投資這樣的專欄,你會得到:

df.groupby(target)['total_reward'].describe()
    count   mean    std    min   25%    50%   75%    max
0   119995  0.0     0.0    0.0   0.0    0.0   0.0    0.0
1   27513   5.74    4.07   2.0   3.0    5.0   10.0   40.0

您可以看到,對於目標 0, total_reward始終為零,否則其值始終大於 0。這是您的泄漏。

由於可能存在其他泄漏並且檢查每一列很乏味,我們可以單獨使用每個特征的一種“預測能力”:

acc_df = pd.DataFrame(columns=['col', 'acc'], index=range(len(X.columns)))

for i, c in enumerate(X.columns):

    clf = DecisionTreeClassifier(criterion = "gini", 
                                 random_state = 100, 
                                 max_depth = 3, 
                                 min_samples_leaf = 5) 
    
    clf.fit(X_train[c].to_numpy()[:, None], y_train)
    
    y_pred = clf.predict(X_test[c].to_numpy()[:, None])
    acc_df.iloc[i] = [c, accuracy_score(y_test, y_pred)*100]


acc_df.sort_values('acc',ascending=False)
                 col      acc
8       total_reward      100
4     completed_time  99.8848
13  reward_each_time  89.3205
14        difficulty  89.3205
15          duration  89.3205
21          discount  86.4054
19               web   85.088
20              bogo  84.4801
3        viewed_time  84.4056
2       offer_viewed  84.3491
18            social  83.3525
1      received_time  83.0497
7             amount  82.5436
0     offer_received  81.7526
16             email  81.7526
17            mobile  81.6464
11              male  81.5651
10            income  81.5651
9                age  81.5651
6   transaction_time  81.5651
5        transaction  81.5651
22     informational  81.5651
12   membership_days  81.5561

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM