Custom loss function on subset of outputs based on inputs

I want to create a loss function where the MSE is only calculated on a subset of the outputs. The subset depends on the input data. I used the answer to this question to figure out how to create a custom function based on the input data:

Custom loss function in Keras based on the input data

However, I'm having trouble getting the custom function to work.

Here is what I've put together.

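from keras import backend as K
from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam


def custom_loss(input_tensor):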
    def loss(y_true, y_pred):
        board = input_tensor[:81]
        answer_vector = board == .5
        #assert np.sum(answer_vector) > 0

        return K.mean(K.square(y_pred * answer_vector - y_true), axis=-1)
    return loss


def build_model(input_size, output_size):
    learning_rate = .001
    a = Input(shape=(input_size,))
    b = Dense(60, activation='relu')(a)
    b = Dense(60, activation='relu')(b)
    b = Dense(60, activation='relu')(b)
    b = Dense(output_size, activation='linear')(b)
    model = Model(inputs=a, outputs=b)
    model.compile(loss=custom_loss(a), optimizer=Adam(lr=learning_rate))

    return model

model = build_model(83, 81)

I want the MSE to treat the output as 0 wherever the board is not equal to 0.5. (The true value is one-hot encoded, with the one falling inside the subset.) For some reason my output is always treated as zero. In other words, the custom loss function doesn't seem to find any places where the board is equal to 0.5.

I can't tell if I'm misinterpreting the dimensions, if the comparisons are failing because of the tensors, or if there is simply a much easier way to do what I'm trying to do.

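The problem is that answer_vector = board == .5 does not do what you think it does. In TensorFlow 1.x, == between a tensor and a number is not an element-wise comparison; it compares the Python objects, so the result is the plain boolean value False rather than a tensor: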

import tensorflow as tf  # TensorFlow 1.x

a = tf.constant([0.5, 0.5])
print(a == 0.5)  # False

Now, a * False is a vector of zeros:

with tf.Session() as sess:
    print(sess.run(a * False))  # [0.0, 0.0]

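You need to use tf.equal instead of ==. Note that tf.equal returns a boolean tensor, so it has to be cast to a float type (for example with tf.cast) before it can be multiplied with y_pred. Another possible pitfall is that comparing floats for exact equality is unreliable; see e.g. "What's wrong with using == to compare floats in Java?"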
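To make the intended arithmetic concrete, here is a minimal NumPy sketch of the masked MSE described in the question. It is not Keras code and not the asker's actual fix: the function name masked_mse and the parameters board_size, target and tol are hypothetical, chosen only for illustration. It assumes the inputs have shape (batch, features), which is why the board is sliced as inputs[:, :board_size]; a plain [:81] would slice the batch axis instead. It also replaces exact float equality with a tolerance check. In a real loss function the same steps would have to be written with tensor operations (for instance tf.equal followed by tf.cast, as mentioned above).

```python
import numpy as np


def masked_mse(inputs, y_true, y_pred, board_size=81, target=0.5, tol=1e-6):
    """Per-sample MSE where predictions outside the mask are treated as 0."""
    # Slice along the feature axis; axis 0 is the batch axis.
    board = inputs[:, :board_size]
    # Tolerance-based comparison instead of exact float equality,
    # converted to floats so it can be multiplied with the predictions.
    mask = (np.abs(board - target) < tol).astype(y_pred.dtype)
    return np.mean(np.square(y_pred * mask - y_true), axis=-1)


# Tiny example: a batch of one sample with 3 board cells and 1 extra feature.
inputs = np.array([[0.5, 1.0, 0.5, 0.0]])
y_true = np.array([[1.0, 0.0, 0.0]])
y_pred = np.array([[0.8, 0.9, 0.2]])

loss = masked_mse(inputs, y_true, y_pred, board_size=3)
print(loss)
```

Here the mask is [1, 0, 1], so the prediction 0.9 in the middle cell is ignored and the loss is (0.04 + 0 + 0.04) / 3, roughly 0.0267.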
