ValueError: Classification metrics can't handle a mix of multiclass and multilabel-indicator targets in ROC curve calculation

I'm trying to draw a ROC curve for multiclass classification.

First, I calculate y_pred and y_proba using the following code:

from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state=0)

# training a DecisionTreeClassifier
dtree_model = DecisionTreeClassifier(max_depth=2).fit(X_train, y_train)

y_pred = dtree_model.predict(X_test)
y_proba = dtree_model.predict_proba(X_test)

After that, I use the following function to calculate tpr and fpr:

from sklearn.metrics import confusion_matrix

def calculate_tpr_fpr(y_test, y_pred):
    '''
    Calculates the True Positive Rate (tpr) and the False Positive Rate (fpr) based on real and predicted observations
    
    Args:
     y_test: The list or series with the real classes
     y_pred: The list or series with the predicted classes
    
    Returns:
     tpr: The True Positive Rate of the classifier
     fpr: The False Positive Rate of the classifier
    '''
    
    # Calculates the confusion matrix and recover each element
    cm = confusion_matrix(y_test, y_pred)
    TN = cm[0, 0]
    FP = cm[0, 1]
    FN = cm[1, 0]
    TP = cm[1, 1]
    
    # Calculates tpr and fpr
    tpr = TP / (TP + FN) # sensitivity - true positive rate
    fpr = 1 - TN / (TN + FP) # 1-specificity - false positive rate
    
    return tpr, fpr

Then I try to use this function to calculate lists of fpr and tpr to draw the curve:

def get_all_roc_coordinates(y_test, y_proba):
    '''
    Calculates all the ROC Curve coordinates (tpr and fpr) by considering each point as a threshold for the prediction of the class.
    
    Args:
     y_test: The list or series with the real classes.
     y_proba: The array with the probabilities for each class, obtained by using the `.predict_proba()` method.
     
    Returns:
     tpr_list: The list of TPRs representing each threshold.
     fpr_list: The list of FPRs representing each threshold.
    '''
    
    tpr_list = [0]
    fpr_list = [0]
    
    for i in range(len(y_proba)):
        threshold = y_proba[i]
        y_pred = y_proba = threshold
        tpr, fpr = calculate_tpr_fpr(y_test, y_pred)
        tpr_list.append(tpr)
        fpr_list.append(fpr)
        
    return tpr_list, fpr_list

but it gives me the following error:

ValueError: Classification metrics can't handle a mix of multiclass and multilabel-indicator targets

Note that the Y column is multiclass {0,1,2}. I also tried making y a string instead of an integer, but it gives me the same error.
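For reference, scikit-learn raises this message when the true and predicted targets are inferred to be of different types. The sketch below is only an illustration under assumptions: the small y_test and y_proba arrays are made-up stand-ins for the question's data, and y_pred_2d is a hypothetical name for what you get when the whole probability matrix is thresholded at once. It uses type_of_target to show how each array is classified.

```python
import numpy as np
from sklearn.utils.multiclass import type_of_target

# Hypothetical stand-ins for the question's data: 4 samples, 3 classes.
y_test = np.array([0, 1, 2, 1])
y_proba = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
    [0.3, 0.6, 0.1],
])

# Thresholding the full matrix keeps its (n_samples, n_classes) shape,
# so the result is a 2D array of 0/1 indicators, not a 1D label vector.
y_pred_2d = (y_proba > 0.5).astype(int)

true_type = type_of_target(y_test)     # 1D integer labels with 3 values
pred_type = type_of_target(y_pred_2d)  # 2D 0/1 indicator matrix

print(true_type)
print(pred_type)
```

Passing two arrays with these differing types to confusion_matrix is what triggers the "mix of multiclass and multilabel-indicator targets" error.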

You have 3 classes, but calculate_tpr_fpr() only handles 2: it reads just the 2x2 corner of the confusion matrix. Also, you probably meant y_pred = y_proba > threshold. Either way, it won't be that easy, since you have 3 columns of class scores. The easiest way seems to be drawing one-vs-rest curves, treating each column individually:

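from sklearn.metrics import roc_curve
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt

classes = list(range(y_proba.shape[1]))
# one indicator column per class
y_test_bin = label_binarize(y_test, classes=classes)

for i in classes:
    # class i vs. the rest
    fpr, tpr, _ = roc_curve(y_test_bin[:, i], y_proba[:, i])
    plt.plot(fpr, tpr, alpha=0.7, label=f"class {i}")

# draw the legend once, after all curves are plotted
plt.legend()
plt.show()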
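
The same one-vs-rest idea can be tried end to end without the question's data. The following is a minimal sketch, assuming the iris dataset as a stand-in for X and Y (it also has three classes) and adding a per-class AUC via sklearn.metrics.auc, which the original answer does not compute. The names y_onehot and roc_auc are illustrative only.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.tree import DecisionTreeClassifier

# Assumption: iris stands in for the question's X and Y (classes 0, 1, 2).
X, Y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, Y, random_state=0)

model = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_train, y_train)
y_proba = model.predict_proba(X_test)

# One indicator column per class, in the same order as the predict_proba columns.
y_onehot = label_binarize(y_test, classes=model.classes_)

roc_auc = {}
for i, label in enumerate(model.classes_):
    # Class `label` vs. the rest: binary truth column against that class's score.
    fpr, tpr, _ = roc_curve(y_onehot[:, i], y_proba[:, i])
    roc_auc[int(label)] = auc(fpr, tpr)
    print(f"class {label}: AUC = {roc_auc[int(label)]:.3f}")
```

Using model.classes_ rather than a hard-coded range keeps the indicator columns aligned with the probability columns, which also matters if the labels are strings rather than integers.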
