
Why am I getting an IndexError for best_precision_threshold?

I ran the following code on one machine without any problems, but when I try to run it on another machine, I get the following error:

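import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
from tqdm import tqdm

class_names = ['Fish', 'Flower', 'Sugar', 'Gravel']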

def get_threshold_for_recall(y_true, y_pred, class_i, recall_threshold=0.94, precision_threshold=0.90, plot=False):

    precision, recall, thresholds = precision_recall_curve(y_true[:, class_i], y_pred[:, class_i])
    i = len(thresholds) - 1
    best_recall_threshold = None
    while best_recall_threshold is None:
        next_threshold = thresholds[i]
        next_recall = recall[i]
        if next_recall >= recall_threshold:
            best_recall_threshold = next_threshold
        i -= 1

    # concise, even though it unnecessarily passes through all the values
    best_precision_threshold = [thres for prec, thres in zip(precision, thresholds) if prec >= precision_threshold][0]

    if plot:
        plt.figure(figsize=(10, 7))
        plt.step(recall, precision, color='r', alpha=0.3, where='post')
        plt.fill_between(recall, precision, alpha=0.3, color='r')
        plt.axhline(y=precision[i + 1])
        recall_for_prec_thres = [rec for rec, thres in zip(recall, thresholds) 
                                 if thres == best_precision_threshold][0]
        plt.axvline(x=recall_for_prec_thres, color='g')
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.ylim([0.0, 1.05])
        plt.xlim([0.0, 1.0])
        plt.legend(['PR curve', 
                    f'Precision {precision[i + 1]: .2f} corresponding to selected recall threshold',
                    f'Recall {recall_for_prec_thres: .2f} corresponding to selected precision threshold'])
        plt.title(f'Precision-Recall curve for Class {class_names[class_i]}')
    return best_recall_threshold, best_precision_threshold


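# model, data_generator_val and num_cores are defined earlier in the notebook (not shown in the question)
y_pred = model.predict_generator(data_generator_val, workers=num_cores)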
y_true = data_generator_val.get_labels()
recall_thresholds = dict()
precision_thresholds = dict()

for i, class_name in tqdm(enumerate(class_names)):
    recall_thresholds[class_name], precision_thresholds[class_name] = get_threshold_for_recall(y_true, y_pred, i, plot=True)
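A separate, latent problem sits in the recall search: if no recall value reaches recall_threshold, the while loop keeps decrementing i past 0. Python's negative indexing then wraps around to the end of thresholds and recall (pairing values at mismatched positions), and the loop typically ends with an IndexError on thresholds[i] instead of a clear "not found". Below is a minimal sketch of a bounded version, assuming thresholds is in increasing order as scikit-learn's precision_recall_curve returns it; the name find_recall_threshold is hypothetical.

```python
from typing import Optional, Sequence


def find_recall_threshold(
    recall: Sequence[float],
    thresholds: Sequence[float],
    recall_target: float = 0.94,
) -> Optional[float]:
    """Return the highest threshold whose recall is still >= recall_target.

    Mirrors the while loop in get_threshold_for_recall: scan from the last
    threshold towards the first, pairing thresholds[i] with recall[i].
    The range stops at index 0, so it never wraps into negative indices.
    """
    for i in range(len(thresholds) - 1, -1, -1):
        if recall[i] >= recall_target:
            return thresholds[i]
    return None  # no point on the curve reaches the target recall


# Made-up curve: recall has one more element than thresholds, as in scikit-learn.
print(find_recall_threshold([1.0, 0.95, 0.80, 0.0], [0.1, 0.4, 0.7]))  # prints 0.4
```

Inside get_threshold_for_recall this could be called as `find_recall_threshold(recall, thresholds, recall_threshold)`, with a None check before the plotting code uses the result.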

I expect four precision-recall curves, one for each of the four classes, but instead I get the following error message:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-79-422044a4f5da> in <module>
     37 precision_thresholds = dict()
     38 for i, class_name in tqdm(enumerate(class_names)):
---> 39     recall_thresholds[class_name], precision_thresholds[class_name] = get_threshold_for_recall(y_true, y_pred, i, plot=True)

<ipython-input-79-422044a4f5da> in get_threshold_for_recall(y_true, y_pred, class_i, recall_threshold, precision_threshold, plot)
     12 
     13     # consice, even though unnecessary passing through all the values
---> 14     best_precision_threshold = [thres for prec, thres in zip(precision, thresholds) if prec > precision_threshold][0]
     15 
     16     if plot:

IndexError: list index out of range
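The error is raised by the trailing `[0]` on the list comprehension: if no point on a class's precision-recall curve has precision at or above precision_threshold (0.90 by default), the comprehension evaluates to an empty list, and indexing an empty list raises `IndexError: list index out of range`. (The traceback shows `>` where the posted code has `>=`; either comparison produces an empty list when no precision value clears the threshold.) The sketch below reproduces this with made-up precision values and shows one way to return None instead, using `next()` with a default; the function name first_threshold_with_precision is hypothetical.

```python
from typing import Optional, Sequence


def first_threshold_with_precision(
    precision: Sequence[float],
    thresholds: Sequence[float],
    precision_target: float = 0.90,
) -> Optional[float]:
    """Return the first threshold whose precision is >= precision_target, or None."""
    # zip() stops at the shorter sequence, so the final precision value
    # (which has no matching threshold) is ignored, as in the original code.
    return next(
        (thres for prec, thres in zip(precision, thresholds) if prec >= precision_target),
        None,  # default returned when no element satisfies the condition
    )


# Made-up curve for a class whose precision never reaches 0.90 at any threshold.
precision = [0.40, 0.55, 0.70, 0.85, 1.0]
thresholds = [0.1, 0.3, 0.5, 0.7]

try:
    [thres for prec, thres in zip(precision, thresholds) if prec >= 0.90][0]
except IndexError as exc:
    print("original expression:", exc)  # prints "original expression: list index out of range"

print(first_threshold_with_precision(precision, thresholds))  # prints "None"
```

The plot branch uses the same pattern (`recall_for_prec_thres = [...][0]`), so it needs a guard as well once best_precision_threshold can be None. Whether to skip that class, lower precision_threshold, or report it is up to the caller; the sketch only makes the "not found" case explicit.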

Please help me figure out what is wrong in the code snippet above. Make sure you are using the correct versions: tensorflow==1.14.0 and keras==2.3.0.
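Since the code works on one machine and fails on the other, comparing the installed library versions on both is a cheap first check. Below is a minimal sketch that imports each module and reports its `__version__` next to the versions named above; sklearn is included only because precision_recall_curve comes from it (no particular version is assumed), and the helper name installed_versions is hypothetical.

```python
import importlib
from typing import Dict, Iterable, Optional

# Versions named in the question; None means "report only, no expected version".
PINNED = {"tensorflow": "1.14.0", "keras": "2.3.0", "sklearn": None}


def installed_versions(modules: Iterable[str]) -> Dict[str, Optional[str]]:
    """Map each module name to its __version__, or None if it cannot be imported."""
    found: Dict[str, Optional[str]] = {}
    for name in modules:
        try:
            module = importlib.import_module(name)
        except ImportError:
            found[name] = None
            continue
        found[name] = getattr(module, "__version__", None)
    return found


if __name__ == "__main__":
    for name, have in installed_versions(PINNED).items():
        wanted = PINNED[name]
        if have is None:
            status = "not installed"
        elif wanted is None or have == wanted:
            status = "ok"
        else:
            status = f"expected {wanted}"
        print(f"{name}: {have} ({status})")
```

Running it on both machines and diffing the output shows whether the two environments actually match.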
