
How to customize losses by subclassing the tf.keras.losses.Loss class in TensorFlow 2.x

When I read the guides on the TensorFlow website, I found two ways to customize losses. The first is to define a loss function, like:

import tensorflow as tf

def basic_loss_function(y_true, y_pred):
    # Mean absolute difference between targets and predictions (L1 loss).
    return tf.math.reduce_mean(tf.abs(y_true - y_pred))

And for the sake of simplicity, we assume the batch size is also 1, so the shapes of y_true and y_pred are both (1, c), where c is the number of classes. So in this method, we pass in two vectors, y_true and y_pred, and return a single value (a scalar).
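As a quick sanity check (a minimal sketch with made-up values, assuming the basic_loss_function above is in scope), the returned loss is indeed a scalar:

import tensorflow as tf

# Hypothetical targets and predictions with batch size 1 and c = 3 classes.
y_true = tf.constant([[1.0, 0.0, 1.0]])
y_pred = tf.constant([[0.8, 0.1, 0.6]])

loss = basic_loss_function(y_true, y_pred)
print(loss.shape)   # () -- a scalar
print(float(loss))  # (0.2 + 0.1 + 0.4) / 3 ≈ 0.2333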

Then, the second method is to subclass the tf.keras.losses.Loss class, and the code in the guide is:

import tensorflow as tf
from tensorflow import keras

class WeightedBinaryCrossEntropy(keras.losses.Loss):
    """
    Args:
      pos_weight: Scalar to affect the positive labels of the loss function.
      weight: Scalar to affect the entirety of the loss function.
      from_logits: Whether to compute loss from logits or the probability.
      reduction: Type of tf.keras.losses.Reduction to apply to loss.
      name: Name of the loss function.
    """
    def __init__(self, pos_weight, weight, from_logits=False,
                 reduction=keras.losses.Reduction.AUTO,
                 name='weighted_binary_crossentropy'):
        super().__init__(reduction=reduction, name=name)
        self.pos_weight = pos_weight
        self.weight = weight
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        # Per-sample crossentropy; [:, None] expands it to shape
        # (batch, 1) so it broadcasts against y_true below.
        ce = tf.losses.binary_crossentropy(
            y_true, y_pred, from_logits=self.from_logits)[:, None]
        # Weight negative and positive labels separately; broadcasting
        # against y_true gives a result of shape (batch, c).
        ce = self.weight * (ce * (1 - y_true) + self.pos_weight * ce * y_true)
        return ce

In the call method, as usual, we pass two vectors y_true and y_pred, but I notice that it returns ce, which is a VECTOR with shape (1, c)!
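A quick check confirms this (a small sketch with made-up numbers, assuming the class and imports above are in scope):

loss_obj = WeightedBinaryCrossEntropy(pos_weight=2.0, weight=1.0)

y_true = tf.constant([[1.0, 0.0, 1.0]])  # batch size 1, c = 3
y_pred = tf.constant([[0.8, 0.1, 0.6]])

print(loss_obj.call(y_true, y_pred).shape)  # (1, 3) -- a vector, not a scalar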

So is there any problem in the above toy example? Or does TensorFlow 2.x have some magic behind that?

The main difference between the two, aside from implementation, is the type of the loss function. The first one is L1 loss (by definition the mean of absolute differences, used mostly for regression-like problems), while the second is binary crossentropy (used for classification). They are not meant to be different implementations of the same loss, and this is stated in the guide you linked.

Binary crossentropy in a multi-label, multi-class classification setting outputs a value for every class, as if the classes were independent of each other.
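For instance (a minimal sketch with made-up numbers), the element-wise form of binary crossentropy keeps one loss term per class:

import tensorflow as tf

y_true = tf.constant([[1.0, 0.0, 1.0]])  # three independent 0/1 labels
y_pred = tf.constant([[0.8, 0.1, 0.6]])

# tf.keras.backend.binary_crossentropy does not reduce over the class axis.
per_class = tf.keras.backend.binary_crossentropy(y_true, y_pred)
print(per_class.shape)  # (1, 3) -- one crossentropy value per class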

Edit:

In the second loss function, the reduction parameter controls how the output is aggregated, e.g. taking the sum of the elements or averaging over the batch. By default, your code uses keras.losses.Reduction.AUTO, which resolves to SUM_OVER_BATCH_SIZE (the sum of the loss elements divided by their number) if you check the source code, so the final loss is reduced to a scalar. With Reduction.NONE the loss stays a vector, and there are other reductions available, which you can check in the docs. I believe that even if you do not define the reduction to take the sum of the loss elements in the loss vector, TF optimizers will do so, to avoid errors from backpropagating a vector. Backpropagation on a vector would cause problems at weights that "contribute" to every loss element. However, I have not checked this in the source code. :)
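As a rough illustration (assuming the WeightedBinaryCrossEntropy class from the question is in scope, with made-up inputs), the shape of the final loss depends entirely on the chosen reduction:

import tensorflow as tf
from tensorflow import keras

y_true = tf.constant([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])  # batch of 2, c = 3
y_pred = tf.constant([[0.8, 0.1, 0.6], [0.3, 0.7, 0.2]])

for reduction in (keras.losses.Reduction.NONE,
                  keras.losses.Reduction.SUM,
                  keras.losses.Reduction.AUTO):
    loss_obj = WeightedBinaryCrossEntropy(
        pos_weight=2.0, weight=1.0, reduction=reduction)
    print(reduction, loss_obj(y_true, y_pred).shape)
# NONE keeps the unreduced (2, 3) tensor, while SUM and AUTO
# (i.e. SUM_OVER_BATCH_SIZE here) both collapse it to a scalar.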
