Keras custom metric with sample weights

I am trying to define a custom metric in Keras that takes sample weights into account. When fitting the model, I pass the sample weights as follows:

training_history = model.fit(
        train_data,
        train_labels,
        sample_weight = train_weights,
        epochs = num_epochs,
        batch_size = 128,
        validation_data = (validation_data, validation_labels, validation_weights),
    )

One of the custom metrics I use is the AUC (area under the ROC curve), which I defined as follows:

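from keras import backend as K
import tensorflow as tf

# Keras calls a metric as metric(y_true, y_pred), so weights is always None here
def auc(true_labels, predictions, weights=None):
    # tf.metrics.auc (TF 1.x) returns (value, update_op); [1] is the update op,
    # so each evaluation updates the streaming counts and returns the new AUC
    auc_value = tf.metrics.auc(true_labels, predictions, weights=weights)[1]
    # Initialise the local (streaming) variables that tf.metrics.auc just created
    K.get_session().run(tf.local_variables_initializer())
    return auc_value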

and I use this metric when compiling the model:

model.compile(
        optimizer = optimizer,
        loss = 'binary_crossentropy',
        metrics = ['accuracy', auc]
    )

But as far as I can tell, the metric does not take the sample weights into account. I verified this by comparing the metric value reported during training (with the custom metric defined above) against the value I get by computing it myself from the model output and the sample weights, and the two are very different. How can I define the auc metric shown above so that it takes the sample weights into account?
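For the kind of by-hand check described in the question, a reference value for a sample-weighted ROC AUC can be computed in plain Python. This is a minimal sketch, assuming 0/1 labels and non-negative per-sample weights; the function name `weighted_auc` is hypothetical. It uses the pairwise (Mann-Whitney) form of the AUC: every (positive, negative) pair contributes the product of the two weights when the positive example is scored higher, half of that on a tie, and the total is divided by (total positive weight) * (total negative weight).

```python
def weighted_auc(true_labels, predictions, weights):
    """Hypothetical helper: exact sample-weighted ROC AUC (pairwise form)."""
    # Split (score, weight) pairs by class; labels are assumed to be 0 or 1
    pos = [(p, w) for y, p, w in zip(true_labels, predictions, weights) if y == 1]
    neg = [(p, w) for y, p, w in zip(true_labels, predictions, weights) if y == 0]
    w_pos = sum(w for _, w in pos)
    w_neg = sum(w for _, w in neg)
    if w_pos == 0 or w_neg == 0:
        raise ValueError("both classes need a positive total weight")

    total = 0.0
    for p_score, p_w in pos:
        for n_score, n_w in neg:
            if p_score > n_score:
                total += p_w * n_w        # correctly ranked pair
            elif p_score == n_score:
                total += 0.5 * p_w * n_w  # tie counts as half
    return total / (w_pos * w_neg)


labels = [0, 0, 1, 1]
preds = [0.1, 0.4, 0.35, 0.8]
print(weighted_auc(labels, preds, [1, 1, 1, 1]))  # 0.75
print(weighted_auc(labels, preds, [1, 3, 1, 1]))  # 0.625
```

The double loop is quadratic in the number of examples, so it only suits spot checks on moderately sized sets. Note that tf.metrics.auc approximates the curve with a fixed set of thresholds (200 by default), so even a correctly weighted streaming metric may differ slightly from this exact value. If scikit-learn is installed, sklearn.metrics.roc_auc_score(labels, preds, sample_weight=...) should give the same number as this sketch.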

You could wrap your metric in another function that takes the sample weights as an argument:

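def auc(weights):
    # The outer function captures the weights tensor ...
    def metric(true_labels, predictions):
        # ... so the inner function can keep the (y_true, y_pred) signature
        # that Keras expects while still passing weights to tf.metrics.auc
        auc_value = tf.metrics.auc(true_labels, predictions, weights=weights)[1]
        K.get_session().run(tf.local_variables_initializer())
        return auc_value
    return metric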
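The same closure pattern can be seen in plain Python, here with a simple weighted accuracy instead of AUC to keep it short. This is an illustrative sketch only: `weighted_accuracy` is a hypothetical name and 0.5 is an assumed decision threshold. The outer call binds `weights`; the function it returns has the two-argument (true_labels, predictions) signature Keras uses when calling a metric.

```python
def weighted_accuracy(weights):
    """Hypothetical example of the wrapper pattern used in the answer."""
    def metric(true_labels, predictions):
        # A prediction counts as correct when thresholding at 0.5 matches the label;
        # each hit contributes its sample weight (w * True == w, w * False == 0)
        hits = [
            w * (int(p >= 0.5) == y)
            for y, p, w in zip(true_labels, predictions, weights)
        ]
        return sum(hits) / sum(weights)
    return metric


metric = weighted_accuracy([1, 1, 2])
print(metric([0, 1, 1], [0.2, 0.3, 0.9]))  # 0.75
```

Because the weights are bound once, when the outer function is called, a fixed array would not line up with the shuffled mini-batches seen during training. That is presumably why the answer binds a placeholder tensor, which can be fed a fresh batch of weights on every step, rather than a concrete array.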

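Then define an extra input placeholder that will receive the sample weights. A placeholder only holds data if it is fed during training, so it presumably also has to be one of the model's inputs (for example Model(inputs=[data_input, sample_weights], ...), where data_input is your existing input layer) and receive train_weights through fit (model.fit([train_data, train_weights], train_labels, ...)):

from keras.layers import Input

sample_weights = Input(shape=(1,))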

Your model can then be compiled as follows:

model.compile(
    optimizer = optimizer,
    loss = 'binary_crossentropy',
    metrics = ['accuracy', auc(sample_weights)]
)

NOTE: Not tested.

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM