繁体   English   中英

将 TensorFlow 损失全局目标 (recall_at_precision_loss) 与 Keras(非指标)一起使用

[英]Use TensorFlow loss Global Objectives (recall_at_precision_loss) with Keras (not metrics)

背景

我有一个带有 5 个标签的多标签分类问题(例如[1 0 1 1 0] )。 因此,我希望我的模型在固定召回、精确召回 AUC 或 ROC AUC 等指标上有所改进。

使用与我想要优化的性能测量没有直接关系的损失函数(例如binary_crossentropy )是没有意义的。 因此,我想使用 TensorFlow 的global_objectives.recall_at_precision_loss()或类似的损失函数。

不是公制的

我不是在寻找实施tf.metrics 我已经在以下方面取得了成功: https : //stackoverflow.com/a/50566908/3399066

问题

我认为我的问题可以分为两个问题:

  1. 如何使用global_objectives.recall_at_precision_loss()或类似的?
  2. 如何在带有 TF 后端的 Keras 模型中使用它?

问题一

全局目标 GitHub 页面上有一个名为loss_layers_example.py的文件(同上) 但是,由于我对TF没有太多经验,所以我不太了解如何使用它。 此外,谷歌搜索TensorFlow recall_at_precision_loss exampleTensorFlow Global objectives example不会给我任何更清晰的示例。

如何在简单的 TF 示例中使用global_objectives.recall_at_precision_loss()

问题二

像(在model.compile(loss = ??.recall_at_precision_loss, ...)中): model.compile(loss = ??.recall_at_precision_loss, ...)就足够了吗? 我的感觉告诉我它比这更复杂,因为使用了loss_layers_example.py使用的全局变量。

如何在global_objectives.recall_at_precision_loss()使用类似于global_objectives.recall_at_precision_loss()损失函数?

我设法通过以下方式使其工作:

  • 显式地将张量整形为 BATCH_SIZE 长度(见下面的代码)
  • 将数据集大小削减为 BATCH_SIZE 的倍数
    def precision_recall_auc_loss(y_true, y_pred):
        y_true = keras.backend.reshape(y_true, (BATCH_SIZE, 1)) 
        y_pred = keras.backend.reshape(y_pred, (BATCH_SIZE, 1))   
        util.get_num_labels = lambda labels : 1
        return loss_layers.precision_recall_auc_loss(y_true, y_pred)[0]

类似于 Martino 的答案,但会从输入推断形状(将其设置为固定的批量大小对我不起作用)。

外部函数并不是绝对必要的,但是在配置损失函数时传递参数感觉更自然一些,尤其是当您的包装器在外部模块中定义时。

import keras.backend as K
from global_objectives.loss_layers import precision_at_recall_loss

def get_precision_at_recall_loss(target_recall): 
    def precision_at_recall_loss_wrapper(y_true, y_pred):
        y_true = K.reshape(y_true, (-1, 1)) 
        y_pred = K.reshape(y_pred, (-1, 1))   
        return precision_at_recall_loss(y_true, y_pred, target_recall)[0]
    return precision_at_recall_loss_wrapper

然后,在编译模型时:

TARGET_RECALL = 0.9
model.compile(optimizer='adam', loss=get_precision_at_recall_loss(TARGET_RECALL))

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM