简体   繁体   English

基于输入的输出子集上的自定义损失函数

[英]Custom loss function on subset of outputs based on inputs

I want to create a loss function where the MSE is only calculated on a subset of the outputs.我想创建一个损失函数,其中 MSE 仅在输出的子集上计算。 The subset depends on the input data.子集取决于输入数据。 I used the answer to this question to figure out how to create a custom function based on the input data:我使用这个问题的答案来弄清楚如何根据输入数据创建自定义函数:

Custom loss function in Keras based on the input data Keras 中基于输入数据的自定义损失函数

However, I'm having trouble implementing the custom function to work.但是,我在实现自定义功能时遇到了麻烦。

Here is what I've put together.这是我整理的内容。

def custom_loss(input_tensor):


    def loss(y_true, y_pred):
        board = input_tensor[:81]
        answer_vector = board == .5
        #assert np.sum(answer_vector) > 0

        return K.mean(K.square(y_pred * answer_vector - y_true), axis=-1)
    return loss


def build_model(input_size, output_size):
    learning_rate = .001
    a = Input(shape=(input_size,))
    b = Dense(60, activation='relu')(a)
    b = Dense(60, activation='relu')(b)
    b = Dense(60, activation='relu')(b)
    b = Dense(output_size, activation='linear')(b)
    model = Model(inputs=a, outputs=b)
    model.compile(loss=custom_loss(a), optimizer=Adam(lr=learning_rate))

    return model

model = build_model(83, 81)

I want the MSE to treat the output as 0 wherever the board is not equal to 0.5.我希望 MSE 在电路板不等于 0.5 的任何地方将输出视为 0。 (The true value is one hot encoded with the one being within the subset). (真正的值是一个热编码,一个在子集中)。 For some reason my output my output is treated as always zero.出于某种原因,我的输出始终被视为零。 In other words, the custom loss function doesn't seem to be finding any places where the board is equal to 0.5.换句话说,自定义损失函数似乎没有找到任何棋盘等于 0.5 的地方。

I can't tell if I'm misinterpretting the dimensions or if the comparisons are failing due to the tensors, or even if there is just a generally much easier approach to do what I'm trying.我不知道我是否误解了维度,或者比较是否由于张量而失败,或者即使只有一种通常更简单的方法来做我正在尝试的事情。

The problem is that answer_vector = board == .5 is not what you think it is.问题是answer_vector = board == .5不是你想的那样。 It is not a tensor, but the boolean value False, since board is a tensor and 0.5 is a number:它不是张量,而是布尔值 False,因为 board 是张量而 0.5 是数字:

a = tf.constant([0.5, 0.5])
print(a == 0.5) # False

Now, a * False is a vector fo zeros:现在, a * False是一个零向量:

with tf.Session() as sess:
   print(sess.run(a * False)) # [0.0, 0.0]

You need to use tf.equal instead of ==.您需要使用 tf.equal 而不是 ==。 Another possible pitfall is that comparing floats with equality is dangerous, see eg What's wrong with using == to compare floats in Java?另一个可能的陷阱是将浮点数与相等进行比较是危险的,请参阅例如在 Java 中使用 == 比较浮点数有什么问题?

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM