简体   繁体   English

具有来自拟合生成器的多个参数的 Keras 自定义损失函数

[英]Keras custom loss function with multiple arguments from fit generator

I want to create a custom loss which gets the output of the net and multiple arguments from a data generator.我想创建一个自定义损失,它从数据生成器获取网络输出和多个参数。

I found this article, which describes how to calculate one loss from multiple layers with one label.我找到了这篇文章,它描述了如何使用一个标签计算多层的一个损失。 But I want to calculate the loss from a single layer with multiple labels using the fit_generator.但我想使用 fit_generator 计算具有多个标签的单层损失。 My problem is that Keras expects the output and the label to be of the same shape.我的问题是 Keras 期望输出和标签具有相同的形状。

example:例子:

Regular custom loss:常规自定义损失:

def custom_loss(y_pred, y_label):
        return K.mean(y_pred - y_label)

An example for the type of custom loss I want to use:我想使用的自定义损失类型的示例:

def custom_loss(y_pred, y_label, y_weights):
     loss = K.mean(y_pred - y_label)
     return tf.compat.v1.losses.compute_weighted_loss(loss, y_weights)

This is just an example my original code is a little more complicated.这只是一个示例,我的原始代码稍微复杂一些。 I just want to be able to give the loss function two parameters (y_label and y_weights) instead of only one (y_label).我只想能够给损失函数两个参数(y_label 和 y_weights),而不是只有一个(y_label)。

Does anyone know how to solve this problem?有谁知道如何解决这个问题?

I am not sure what exactly you are asking but maybe you can use this.我不确定你到底在问什么,但也许你可以使用它。 You can try something like a custom function that returns a loss function.您可以尝试类似返回损失函数的自定义函数。

def custom_loss(y_weights):

    # Create a loss function that calculates what you want
    def example_loss(y_true,y_pred):
        loss = K.mean(y_pred - y_label)
        return tf.compat.v1.losses.compute_weighted_loss(loss, y_weights)

    # Return a function
    return example_loss

# Compile the model
model.compile(optimizer='adam',
          loss=custom_loss(y_weights), # Call the loss function with the preferred weights
          metrics=['accuracy'])

You can also take a look at this question你也可以看看这个问题

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM