繁体   English   中英

如何 plot x 和 y 在曲线的不同点处截取

[英]How to plot x and y intercepts at different points of a curve

我已经训练了一个二进制分类 model。

我能够在 model 的不同决策阈值下获得成对的精度和召回值,如下所示:

test_prob = model.predict_proba(test_x)[:, 1]
precisions, recalls, thresholds = precision_recall_curve(test_y, test_prob)

我可以使用matplotlib来 plot PR 曲线:

plt.plot(recalls, precisions, label=f"Chargbacks (AUC = {round(pr_auc, 2)})", c="b")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.legend()
plt.show()

这会产生这个 plot:

在此处输入图像描述

我还可以为不同的决策阈值创建相应精度和召回对的 dataframe,如下所示:

thresholds = pd.DataFrame(
   {
        "Threshold": thresholds, 
        "Precision": precisions[:-1], 
        "Recall": recalls[:-1]
   }
)

这会产生这个 dataframe:

     Threshold  Precision    Recall
0     0.000000   0.005016  1.000000
1     0.002222   0.056515  0.990991
2     0.010000   0.056555  0.990991
3     0.020000   0.113995  0.989189
4     0.030000   0.163076  0.981982
5     0.031667   0.203295  0.978378
6     0.031667   0.203371  0.978378
7     0.040000   0.203447  0.978378
8     0.050000   0.243341  0.971171
9     0.060000   0.282347  0.971171
10    0.070000   0.321128  0.963964
11    0.080000   0.355898  0.956757
12    0.090000   0.383883  0.944144
13    0.100000   0.405594  0.940541
14    0.110000   0.431063  0.935135
15    0.120000   0.460036  0.933333
16    0.130000   0.484082  0.931532
17    0.140000   0.508374  0.929730
18    0.150000   0.530864  0.929730
19    0.160000   0.550694  0.929730
20    0.170000   0.571109  0.918919
21    0.180000   0.587082  0.917117
22    0.190000   0.607914  0.913514
23    0.200000   0.622850  0.913514
24    0.210000   0.644955  0.909910
25    0.220000   0.653696  0.908108
26    0.230000   0.665779  0.900901
27    0.240000   0.680384  0.893694
28    0.250000   0.688456  0.891892
29    0.260000   0.698300  0.888288
30    0.270000   0.700855  0.886486
31    0.280000   0.706052  0.882883
32    0.290000   0.711790  0.881081
33    0.300000   0.719764  0.879279
34    0.310000   0.726727  0.872072
35    0.320000   0.730594  0.864865
36    0.330000   0.735069  0.864865
37    0.340000   0.744946  0.863063
38    0.350000   0.750392  0.861261
39    0.360000   0.756757  0.857658
40    0.370000   0.761218  0.855856
41    0.380000   0.766990  0.854054
42    0.390000   0.768852  0.845045
43    0.400000   0.777778  0.845045
44    0.410000   0.781513  0.837838
45    0.420000   0.787053  0.832432
46    0.430000   0.791096  0.832432
47    0.439630   0.792746  0.827027
48    0.440000   0.792388  0.825225
49    0.450000   0.793043  0.821622
50    0.460000   0.793345  0.816216
51    0.470000   0.799645  0.812613
52    0.480000   0.803220  0.809009
53    0.490000   0.805755  0.807207
54    0.500000   0.809872  0.798198
55    0.510000   0.809524  0.796396
56    0.520000   0.814815  0.792793
57    0.530000   0.819887  0.787387
58    0.540000   0.823864  0.783784
59    0.550000   0.825670  0.776577
60    0.560000   0.826590  0.772973
61    0.570000   0.828125  0.763964
62    0.580000   0.827789  0.762162
63    0.590000   0.832016  0.758559
64    0.600000   0.831349  0.754955
65    0.610000   0.832335  0.751351
66    0.620000   0.834694  0.736937
67    0.630000   0.836066  0.735135
68    0.640000   0.844075  0.731532
69    0.650000   0.845511  0.729730
70    0.660000   0.844211  0.722523
71    0.670000   0.846809  0.717117
72    0.680000   0.846482  0.715315
73    0.690000   0.850649  0.708108
74    0.700000   0.857768  0.706306
75    0.710000   0.863135  0.704505
76    0.720000   0.868889  0.704505
77    0.730000   0.876404  0.702703
78    0.740000   0.876147  0.688288
79    0.750000   0.875862  0.686486
80    0.760000   0.874126  0.675676
81    0.770000   0.874408  0.664865
82    0.780000   0.872596  0.654054
83    0.790000   0.882064  0.646847
84    0.800000   0.883085  0.639640
85    0.810000   0.887218  0.637838
86    0.820000   0.890585  0.630631
87    0.830000   0.890625  0.616216
88    0.840000   0.898396  0.605405
89    0.850000   0.898907  0.592793
90    0.860000   0.899441  0.580180
91    0.870000   0.901449  0.560360
92    0.880000   0.903904  0.542342
93    0.890000   0.907407  0.529730
94    0.900000   0.911672  0.520721
95    0.910000   0.912621  0.508108
96    0.920000   0.915541  0.488288
97    0.930000   0.916955  0.477477
98    0.940000   0.927536  0.461261
99    0.950000   0.932331  0.446847
100   0.960000   0.931174  0.414414
101   0.970000   0.939130  0.389189
102   0.980000   0.938095  0.354955
103   0.990000   0.935484  0.313514
104   1.000000   0.928058  0.232432

On the same plot as the PR Curve, I now want to plot horizontal dotted lines at y-values [0.1, 0.2, ..., 0.9] (the closest values, if available, to these in the dataframe above) that hit the蓝色曲线,然后垂直下降到 x 轴。 这些中的每一个都应标记为上面 dataframe 中的相应“阈值”。

我怎样才能做到这一点?

最终的 plot 应该如下所示:

在此处输入图像描述

编辑:

与其在每个precision = [0.1, ..., 0.9]处绘制截距,不如为每个threshold = [0.1, ..., 0.9] ,但同样的问题仍然存在于这个调整中.

idx = (np.abs(threshold - t)).argmin()threshold中找到最接近t的值的索引。 此索引可用于绘制线条和 position 文本。 可以类似地绘制给定精度的线。

import matplotlib.pyplot as plt
import numpy as np

threshold = np.array([0.0, 0.002222, 0.01, 0.02, 0.03, 0.031667, 0.031667, 0.04, 0.05, 0.06, 0.07, 0.08, 0.09, 0.1, 0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.2, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.3, 0.31, 0.32, 0.33, 0.34, 0.35, 0.36, 0.37, 0.38, 0.39, 0.4, 0.41, 0.42, 0.43, 0.43963, 0.44, 0.45, 0.46, 0.47, 0.48, 0.49, 0.5, 0.51, 0.52, 0.53, 0.54, 0.55, 0.56, 0.57, 0.58, 0.59, 0.6, 0.61, 0.62, 0.63, 0.64, 0.65, 0.66, 0.67, 0.68, 0.69, 0.7, 0.71, 0.72, 0.73, 0.74, 0.75, 0.76, 0.77, 0.78, 0.79, 0.8, 0.81, 0.82, 0.83, 0.84, 0.85, 0.86, 0.87, 0.88, 0.89, 0.9, 0.91, 0.92, 0.93, 0.94, 0.95, 0.96, 0.97, 0.98, 0.99, 1.0])
precisions =  np.array([0.005016, 0.056515, 0.056555, 0.113995, 0.163076, 0.203295, 0.203371, 0.203447, 0.243341, 0.282347, 0.321128, 0.355898, 0.383883, 0.405594, 0.431063, 0.460036, 0.484082, 0.508374, 0.530864, 0.550694, 0.571109, 0.587082, 0.607914, 0.62285, 0.644955, 0.653696, 0.665779, 0.680384, 0.688456, 0.6983, 0.700855, 0.706052, 0.71179, 0.719764, 0.726727, 0.730594, 0.735069, 0.744946, 0.750392, 0.756757, 0.761218, 0.76699, 0.768852, 0.777778, 0.781513, 0.787053, 0.791096, 0.792746, 0.792388, 0.793043, 0.793345, 0.799645, 0.80322, 0.805755, 0.809872, 0.809524, 0.814815, 0.819887, 0.823864, 0.82567, 0.82659, 0.828125, 0.827789, 0.832016, 0.831349, 0.832335, 0.834694, 0.836066, 0.844075, 0.845511, 0.844211, 0.846809, 0.846482, 0.850649, 0.857768, 0.863135, 0.868889, 0.876404, 0.876147, 0.875862, 0.874126, 0.874408, 0.872596, 0.882064, 0.883085, 0.887218, 0.890585, 0.890625, 0.898396, 0.898907, 0.899441, 0.901449, 0.903904, 0.907407, 0.911672, 0.912621, 0.915541, 0.916955, 0.927536, 0.932331 , 0.931174, 0.93913, 0.938095, 0.935484, 0.928058])
recalls = np.array([1.0, 0.990991, 0.990991, 0.989189, 0.981982, 0.978378, 0.978378, 0.978378, 0.971171, 0.971171, 0.963964, 0.956757, 0.944144, 0.940541, 0.935135, 0.933333, 0.931532, 0.92973, 0.92973, 0.92973, 0.918919, 0.917117, 0.913514, 0.913514, 0.90991, 0.908108, 0.900901, 0.8936940, 0.891892, 0.888288, 0.886486, 0.882883, 0.881081, 0.879279, 0.872072, 0.864865, 0.864865, 0.863063, 0.861261, 0.857658, 0.855856, 0.854054, 0.845045, 0.845045, 0.837838, 0.832432, 0.832432, 0.8270270000000001, 0.825225, 0.821622, 0.816216, 0.812613, 0.809009, 0.807207, 0.798198, 0.796396, 0.792793, 0.787387, 0.783784, 0.776577, 0.772973, 0.763964, 0.762162, 0.758559, 0.754955, 0.751351, 0.736937, 0.735135, 0.731532, 0.72973, 0.722523, 0.717117, 0.715315, 0.708108, 0.706306, 0.704505, 0.704505, 0.702703, 0.688288, 0.686486, 0.675676, 0.664865, 0.654054, 0.646847, 0.63964, 0.637838, 0.630631, 0.616216, 0.605405, 0.592793, 0.58018, 0.56036, 0.542342, 0.52973, 0.520721, 0.508108, 0.488288, 0.477477, 0.461261, 0.446847, 0.414414, 0.389189, 0.354955, 0.313514, 0.232432])

fig, axs = plt.subplots(ncols=2, figsize=(10, 4))

for ax in axs:
    ax.plot(recalls, precisions, label=f"Chargbacks (AUC = {round(0.85, 2)})", c="b")

    if ax == axs[0]:
        for p in np.arange(0.1, 1, 0.1):
            idx = (np.abs(precisions - p)).argmin()
            ax.plot([recalls[idx], recalls[idx], 0], [0, precisions[idx], precisions[idx]], c='crimson')
            ax.text(0.02,precisions[idx], t, color='crimson', fontsize=10, va='bottom', ha='left' )
    else:
        for i in range(1, 10):
            t = i * 0.1
            idx = (np.abs(threshold - t)).argmin()
            ax.plot([recalls[idx], recalls[idx], 0], [0, precisions[idx], precisions[idx]], c='crimson')
            ax.text(0.02 if i % 2 == 1 else 0.07, precisions[idx], threshold[idx], color='black', fontsize=10, va='bottom', ha='left' )

    ax.set_xlim(xmin=0)
    ax.set_ylim(ymin=0)

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.legend()
plt.show()

示例图

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM