
Load a Keras model with custom metrics and a custom loss

I have created a Keras model by subclassing keras.Model. I have also used a custom loss (focal loss), custom metrics (subclassing keras.metrics.Metric), and learning-rate decay. I trained the model and saved it using tf.keras.callbacks.ModelCheckpoint(model_path).

When I try to load the model, I get the following error:

ValueError: Unable to restore custom object of type _tf_keras_metric currently. Please make sure that the layer implements get_config and from_config when saving. In addition, please use the custom_objects arg when calling load_model()

After digging into the error, I learned about passing custom_objects. However, after reading about it and trying a few things, I am still not able to load the model. Could someone tell me the correct way to do it? My code is as follows:

import tensorflow as tf


def get_metrics():
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(name="train_accuracy")
    val_accuracy = tf.keras.metrics.CategoricalAccuracy(name="val_accuracy")
    confusion_matrix = ConfusionMatrixMetric(20)
    return confusion_matrix, train_accuracy, val_accuracy


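def loss_fn(labels, logits, gamma=2.0, alpha=0.25):
    # Focal loss. gamma and alpha were undefined globals in the snippet;
    # the commonly used defaults 2.0 and 0.25 are assumed here.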
    epsilon = 1e-9
    model_out = tf.nn.softmax(logits, axis=-1) + epsilon
    ce = - (tf.math.log(model_out) * labels)
    weight = labels * tf.math.pow(1 - model_out, gamma)
    fl = alpha * weight * ce
    loss = tf.reduce_max(fl, axis=-1)
    return loss


def get_optimizer(steps_per_epoch, finetune=False):
    lr = 0.001
    if finetune:
        lr = 0.00001
    lr_fn = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [steps_per_epoch * 10], [lr, lr / 10], name=None
    )
    opt_op = tf.keras.optimizers.Adam(learning_rate=lr_fn)
    return opt_op
        
class MyModel(tf.keras.Model):
    def compile(self, optimizer, loss_fn, metric_fn):
        # Let Keras register the optimizer; loss and metrics are handled manually in train_step
        super(MyModel, self).compile(optimizer=optimizer)
        self.loss_fn = loss_fn
        self.confusion_matrix, self.train_accuracy, self.val_accuracy = metric_fn()

    def train_step(self, train_data):
        X, y = train_data
        with tf.GradientTape() as tape:
            logits = self(X, training=True)
            loss = self.loss_fn(y, logits)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update the running metrics from the model's predictions
        # (softmax of the logits, not of the labels)
        y_pred = tf.nn.softmax(logits, axis=-1)

        self.train_accuracy.update_state(y, y_pred)
        self.confusion_matrix.update_state(y, y_pred)

        update_dict = {"train_accuracy": self.train_accuracy.result()}
        # Add one F1_<i> entry per class to the logged results
        self.confusion_matrix.add_results(update_dict)
        return update_dict
    
class ConfusionMatrixMetric(tf.keras.metrics.Metric):
    def __init__(self, num_classes, **kwargs):
        super(ConfusionMatrixMetric, self).__init__(name='confusion_matrix_metric', **kwargs)  # handles base args (e.g., dtype)
        self.num_classes = num_classes
        self.total_cm = self.add_weight("total", shape=(num_classes, num_classes), initializer="zeros")

    def reset_states(self):
        for s in self.variables:
            s.assign(tf.zeros(shape=s.shape))

    def update_state(self, y_true, y_pred, sample_weight=None):
        self.total_cm.assign_add(self.confusion_matrix(y_true, y_pred))
        return self.total_cm

    def result(self):
        return self.process_confusion_matrix()

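    def confusion_matrix(self, y_true, y_pred):
        # tf.math.confusion_matrix expects class indices, so convert both the
        # one-hot labels and the predicted probabilities with argmax
        y_true = tf.math.argmax(y_true, 1)
        y_pred = tf.math.argmax(y_pred, 1)
        cm = tf.math.confusion_matrix(y_true, y_pred, dtype=tf.float32, num_classes=self.num_classes)
        return cm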

    def process_confusion_matrix(self):
        cm = self.total_cm
        diag_part = tf.linalg.diag_part(cm)
        # accuracy = tf.math.reduce_sum(diag_part) / (tf.math.reduce_sum(cm) + tf.constant(1e-15))
        precision = diag_part / (tf.math.reduce_sum(cm, 0) + tf.constant(1e-15))
        recall = diag_part / (tf.math.reduce_sum(cm, 1) + tf.constant(1e-15))
        f1 = 2 * precision * recall / (precision + recall + tf.constant(1e-15))
        return f1

    def add_results(self, output):
        results = self.result()
        for i in range(self.num_classes):
            output['F1_{}'.format(i)] = results[i]

if __name__ == "__main__":
    model_path = 'model/my_custom_model/'
    create_folder(model_path)
    callbacks = [tf.keras.callbacks.ModelCheckpoint(model_path)]
    # train
    # inputs/outputs, create_folder and train_data_gen are defined elsewhere in the original script
    model = MyModel(inputs, outputs)
    model.summary()
    opt_op = get_optimizer(100)

    model.compile(optimizer=opt_op,
                  loss_fn=loss_fn,
                  metric_fn=get_metrics)

    model.fit(train_data_gen(),
              epochs=10,
              callbacks=callbacks)

    tf.keras.models.load_model(model_path)

Sorry for the long code, but I wanted to make sure that what I am doing is correct and understandable.
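To check what ConfusionMatrixMetric is supposed to compute independently of Keras, here is a minimal NumPy sketch (NumPy is an assumption; the question itself only uses TensorFlow, and the batches below are made-up data). It mirrors the metric's logic: both the one-hot labels and the predicted probabilities are reduced to class indices with argmax, since tf.math.confusion_matrix expects integer class indices with rows as true labels and columns as predictions; per-batch matrices are summed the way total_cm.assign_add does; and per-class F1 comes from column sums (precision) and row sums (recall).

```python
import numpy as np


def confusion_matrix(y_true_onehot, y_prob, num_classes):
    """Rows = true class, columns = predicted class (same layout as tf.math.confusion_matrix)."""
    true_idx = np.argmax(y_true_onehot, axis=1)  # one-hot labels -> class indices
    pred_idx = np.argmax(y_prob, axis=1)         # probabilities -> predicted class
    cm = np.zeros((num_classes, num_classes), dtype=np.float32)
    np.add.at(cm, (true_idx, pred_idx), 1.0)     # count each (true, predicted) pair
    return cm


def f1_per_class(cm, eps=1e-15):
    diag = np.diag(cm)
    precision = diag / (cm.sum(axis=0) + eps)  # column sums = how often each class was predicted
    recall = diag / (cm.sum(axis=1) + eps)     # row sums = how often each class really occurs
    return 2 * precision * recall / (precision + recall + eps)


# Two hypothetical batches with 3 classes, accumulated like total_cm.assign_add(...)
batch1_true = np.eye(3)[[0, 1, 2, 2]]
batch1_prob = np.array([[0.8, 0.1, 0.1],
                        [0.2, 0.7, 0.1],
                        [0.1, 0.2, 0.7],
                        [0.6, 0.3, 0.1]])
batch2_true = np.eye(3)[[0, 1]]
batch2_prob = np.array([[0.5, 0.4, 0.1],
                        [0.1, 0.3, 0.6]])

total_cm = np.zeros((3, 3), dtype=np.float32)
for y_true, y_prob in [(batch1_true, batch1_prob), (batch2_true, batch2_prob)]:
    total_cm += confusion_matrix(y_true, y_prob, num_classes=3)

print(total_cm)
print(f1_per_class(total_cm))
```

Accumulating raw counts and deriving F1 only in result() is what makes the metric correct across batches: averaging per-batch F1 scores would generally give a different number.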

As the error message says, you need to implement the get_config method if you want to subclass, use, and load a custom metric.

You have correctly built your metric by subclassing tf.keras.metrics.Metric; you only need to add get_config and return your constructor parameters from it (from what I see, you only have num_classes):

def get_config(self):
    base_config = super().get_config()
    return {**base_config, "num_classes": self.num_classes}
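One caveat is worth checking, assuming Keras' default from_config simply calls cls(**config) (as Layer.from_config does in tf.keras 2.x) and that the base Metric.get_config returns name and dtype. In that case the restored config already contains name, so an __init__ that passes a hard-coded name=... to super() alongside **kwargs, like the one in the question, receives name twice and raises a TypeError on reload. The sketch below reproduces this with a tiny stand-in base class so it runs without TensorFlow; FakeMetricBase, BrokenCM, FixedCM and round_trip are hypothetical names, not Keras classes.

```python
class FakeMetricBase:
    """Hypothetical stand-in for the config handling of tf.keras.metrics.Metric."""

    def __init__(self, name=None, dtype=None):
        self.name = name
        self.dtype = dtype

    def get_config(self):
        # Like the real Metric.get_config, the base config holds name and dtype
        return {"name": self.name, "dtype": self.dtype}

    @classmethod
    def from_config(cls, config):
        # Like Keras' default from_config: rebuild the object from its config
        return cls(**config)


class BrokenCM(FakeMetricBase):
    def __init__(self, num_classes, **kwargs):
        # Hard-coded name, as in the question: collides with config["name"] on reload
        super().__init__(name="confusion_matrix_metric", **kwargs)
        self.num_classes = num_classes

    def get_config(self):
        return {**super().get_config(), "num_classes": self.num_classes}


class FixedCM(FakeMetricBase):
    def __init__(self, num_classes, name="confusion_matrix_metric", **kwargs):
        # name is an explicit keyword with a default, so a restored config can pass it
        super().__init__(name=name, **kwargs)
        self.num_classes = num_classes

    def get_config(self):
        return {**super().get_config(), "num_classes": self.num_classes}


def round_trip(cls):
    """Serialize with get_config, rebuild with from_config, report the outcome."""
    original = cls(20, dtype="float32")
    try:
        restored = cls.from_config(original.get_config())
    except TypeError as err:
        return f"TypeError: {err}"
    return restored.get_config()


print(round_trip(BrokenCM))  # a TypeError message: name is passed twice
print(round_trip(FixedCM))
```

Applied to the real metric, this would mean declaring the constructor as def __init__(self, num_classes, name='confusion_matrix_metric', **kwargs) and calling super().__init__(name=name, **kwargs).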

Also, when you load the model, you must pass your custom metric through custom_objects:

tf.keras.models.load_model(model_path, custom_objects={"ConfusionMatrixMetric": ConfusionMatrixMetric})

Beware of the following, though (from the book Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow by Aurélien Géron, 2nd Edition):

The Keras API currently only specifies how to use subclassing to define layers, models, callbacks, and regularizers. If you build other components (such as losses, metrics, initializers, or constraints) using subclassing, they may not be portable to other Keras implementations. It is likely that the Keras API will be updated to specify subclassing for all these components as well.
