
Plotting Threshold (precision_recall curve) matplotlib/sklearn.metrics

I am trying to plot the thresholds for my precision/recall curve. I am using the MNIST data, following the example from the book Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow, and training the model to detect images of 5s. I don't know how much of the code you need to see. I have made my confusion matrix for the training set and have calculated the precision and recall values, along with the thresholds. I have plotted the precision/recall curve, and the example in the book says to add axis labels, a legend, a grid, and to highlight the thresholds, but the code cuts off in the book where I placed an asterisk below. I was able to figure out everything except how to get the thresholds to show up on the plot. I have included a picture of what the graph in the book looks like vs. what I have. This is what the book shows:

[image: enter image description here] vs. my graph:

[image: enter image description here]

I can't get that red dotted line with the two threshold points to show up. Does anyone have any idea how I would do this? Here is my code:

import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

def plot_precision_recall_vs_thresholds(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g--", label="Recall")
    plt.xlabel("Threshold")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.grid(True, which="both", axis="both", color='gray', linestyle='-', linewidth=1)

plot_precision_recall_vs_thresholds(precisions, recalls, thresholds)
plt.show()

I know there's a decent amount of questions on here about this with sklearn but none seem to cover getting that red line to show up. I would greatly appreciate the help!

You can use the following code for plotting horizontal and vertical lines:

plt.axhline(y_value, c='r', ls=':')
plt.axvline(x_value, c='r', ls=':')
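
As a minimal sketch of how those two calls fit onto a threshold plot, the example below uses made-up precision and recall arrays (not MNIST results) and a hypothetical helper, `mark_threshold`, that draws the guide lines and a dot at the chosen point. Note that `axhline` and `axvline` span the whole axes, so the lines will not stop at the marked point the way the book's figure does.

```python
import matplotlib.pyplot as plt
import numpy as np

# Made-up curves for illustration only; real values would come from
# sklearn.metrics.precision_recall_curve.
thresholds = np.arange(-4, 5, dtype=float)
precisions = np.array([0.20, 0.30, 0.45, 0.60, 0.70, 0.82, 0.90, 0.95, 1.00])
recalls = np.array([1.00, 0.97, 0.93, 0.85, 0.72, 0.55, 0.35, 0.15, 0.00])

def mark_threshold(ax, x_value, y_value):
    """Hypothetical helper: red dotted guide lines plus a red dot at (x_value, y_value)."""
    vline = ax.axvline(x_value, c="r", ls=":")
    hline = ax.axhline(y_value, c="r", ls=":")
    ax.plot([x_value], [y_value], "ro")
    return vline, hline

fig, ax = plt.subplots()
ax.plot(thresholds, precisions, "b--", label="Precision")
ax.plot(thresholds, recalls, "g-", label="Recall")

# First index where the (made-up) precision reaches 80%
idx = int(np.argmax(precisions >= 0.80))
vline, hline = mark_threshold(ax, thresholds[idx], precisions[idx])

ax.set_xlabel("Threshold")
ax.legend()
# Call plt.show() here to display the figure interactively.
```

If the lines should end at the marked point instead of crossing the whole plot, plotting explicit two-point segments with `plt.plot` is the alternative.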

This should reproduce the figure from the book:

import matplotlib.pyplot as plt
import numpy as np

def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    # Index of the first point where precision reaches 80%
    idx_80_precision = np.argmax(precisions >= 0.80)
    recall_80_precision = recalls[idx_80_precision]
    threshold_80_precision = thresholds[idx_80_precision]

    plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
    plt.xlabel("Threshold")
    # Red dotted vertical line at the chosen threshold
    plt.plot([threshold_80_precision, threshold_80_precision], [0., 0.8], "r:")
    # The x-limits are hard-coded to [-4, 4]; adjust them to the range of your scores
    plt.axis([-4, 4, 0, 1])
    # Red dotted horizontal lines at 80% precision and at the matching recall
    plt.plot([-4, threshold_80_precision], [0.8, 0.8], "r:")
    plt.plot([-4, threshold_80_precision], [recall_80_precision, recall_80_precision], "r:")
    # Red dots marking the two threshold points
    plt.plot([threshold_80_precision], [0.8], "ro")
    plt.plot([threshold_80_precision], [recall_80_precision], "ro")
    plt.grid(True)
    plt.legend()
    plt.show()
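
The `np.argmax(precisions >= 0.80)` idiom works because `argmax` on a boolean array returns the index of the first `True`. The sketch below uses made-up numbers and a hypothetical helper, `first_index_at_least`, to show the lookup along with one edge case worth knowing: if no value reaches the target, plain `argmax` returns 0 instead of raising an error.

```python
import numpy as np

def first_index_at_least(values, target):
    """Hypothetical helper: first index where values >= target, or None if there is none."""
    mask = np.asarray(values) >= target
    if not mask.any():
        return None  # plain np.argmax(mask) would silently return 0 here
    return int(np.argmax(mask))

# Made-up values; thresholds is one element shorter than precisions,
# matching the shapes returned by precision_recall_curve.
precisions = np.array([0.50, 0.70, 0.85, 0.80, 1.00])
thresholds = np.array([-2.0, -1.0, 0.5, 1.5])

idx = first_index_at_least(precisions, 0.80)
print(idx, thresholds[idx])                     # 2 0.5
print(first_index_at_least(precisions, 1.50))   # None
print(int(np.argmax(precisions >= 1.50)))       # 0
```

Because `precisions` has one more element than `thresholds`, an index pointing at the final precision value would fall outside `thresholds`, so a bounds check may be worth adding when the target precision is very high.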

I came across this code while trying to replicate the code in this book. It turns out @ageron placed all of the resources on his GitHub page. You could check it out here
