繁体   English   中英

你如何有效地处理张量的维度?

[英]How do you efficiently work with the dimensions of tensors?

我正在尝试为我的机器学习 model 编写自定义损失 function。 具体来说,我试图减少误报。 这是我的第一次尝试:

def sens_spec(y_true, y_pred, sens_weight, spec_weight):
    FN = tf.math.maximum(0, y_pred[:][0]-y_true[:][0])
    FP = tf.math.maximum(0, y_pred[:][1]-y_true[:][1])
    TN = tf.math.minimum(y_pred[:][0],y_true[:][0])
    TP = tf.math.minimum(y_pred[:][1],y_true[:][1])

    FN = tf.math.reduce_sum(FN)
    FP = tf.math.reduce_sum(FP)
    TN = tf.math.reduce_sum(TN)
    TP = tf.math.reduce_sum(TP)
    
    sensitivity = TP / (TP + FN + K.epsilon())
    specificity = TN / (TN + FP + K.epsilon())
    
    return tf.math.subtract(1.0, (sens_weight*sensitivity, spec_weight*specificity))
    
    
def custom_loss(sens_weight, spec_weight):    
    def spec_loss(y_true, y_pred):
        return sens_spec(y_true, y_pred, sens_weight, spec_weight)
    return spec_loss    

因此,例如,假设这是我们的输入:

y_true = [[0, 1], [1, 0]] # ["positive", "negative"]
y_pred = [[0, 1], [0, 1]] # ["positive", "positive"]
sens_weight = 0.1
spec_weight = 0.9

那么结果将是:

FN == 0 # 0 False Negatives
FP == 1 # 1 False Positive
TN == 0 # 0 True Negatives
TP == 1 # 1 True Positive

sensitivity == 1
specificity == 0

loss == 0.9

它确实有效,但速度很慢。 我怀疑这是因为切片效率低下。 model 输出“负”为 [1, 0] 和“正”为 [0, 1]。 因此,y_true 和 y_pred 的形状是 (batchsize, 2),我必须访问最后一个维度来计算损失 function。

如何有效地访问和应用数学到张量的最后一维?

如果您的标签和预测是一个元素而不是两个元素,那会简单得多。 如果您正在进行二进制分类,则可以从第一列推导出第二列(反之亦然),从而不需要索引(并可能加快计算速度)。


import tensorflow as tf


y_true = tf.constant([[0], [1]], tf.float32)
y_pred = tf.constant([[0], [0]], tf.float32)

sens_w = tf.constant(0.1)
spec_w = tf.constant(0.9)

FN = tf.reduce_sum(tf.maximum(0, y_pred - y_true))
FP = tf.reduce_sum(tf.maximum(0, y_true - y_pred))
TN = tf.reduce_sum(tf.minimum(y_pred, y_pred)
TP = tf.reduce_sum(tf.minimum(1-y_pred, 1-y_true))

sens = TP / (TP + FN + 1e-10)
spec = TN / (TN + FP + 1e-10)

tf.subtract(1.0, (sens_w * sens, spec_w * spec))
# <tf.Tensor: shape=(2,), dtype=float32, numpy=array([0.9, 1. ], dtype=float32)>

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM