
Load custom loss with extra input in keras

I have a custom loss function that takes the model's input as one of its arguments. If I load the model in the same session in which I trained it, I can load it without problems using this technique:


import keras
from keras.models import load_model

def custom_loss(inputs):
    # factory: returns a loss function that closes over the model's input tensor
    def loss(y_true, y_pred):
        return ...
    return loss

inputs = keras.layers.Input(shape=(...))
y = keras.layers.Activation('tanh')(inputs)

model = keras.models.Model(inputs=inputs, outputs=y)

model.compile(loss=custom_loss(inputs), optimizer='Adam')
model.fit(...)
model.save('mymodel.h5')
load_model('mymodel.h5', custom_objects={'custom_loss': custom_loss(inputs)})

However, I run into problems when I try to load the model in a later session, because then I no longer have access to the original input tensor. If I create a new inputs placeholder, the loaded model ends up expecting two different sets of inputs and loading fails with an error.

# later session: this is a new placeholder, not the saved model's own input
inputs = keras.layers.Input(shape=(...))
load_model('mymodel.h5', custom_objects={'custom_loss': custom_loss(inputs)})

Is there a good way to solve this problem? The underlying issue is that, at the point where custom_objects is built, the model's inputs have not been deserialized yet, so they cannot be passed to the custom loss.
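A standard-library toy of this mechanism may make the problem concrete: a loss produced by a factory is a closure, and serialization can only record a name, never the captured object. save_config and load_config below are hypothetical stand-ins, not Keras functions, and the loss formula is arbitrary. The assumption, based on our understanding of how tf.keras 2.x saves a plain-function loss, is that it is recorded by its __name__, which here is loss rather than custom_loss.

```python
# Toy model of the problem, standard library only. save_config/load_config
# are hypothetical stand-ins for Keras serialization, NOT Keras functions.

def custom_loss(inputs):
    # factory: the returned function closes over `inputs`
    def loss(y_true, y_pred):
        # hypothetical loss that weights each squared error by |input|
        return sum(abs(x) * (t - p) ** 2
                   for t, p, x in zip(y_true, y_pred, inputs)) / len(inputs)
    return loss


def save_config(loss_fn):
    # only the function's name can be written out, not the captured inputs
    return {"loss": loss_fn.__name__}


def load_config(config, custom_objects):
    # on load, the stored name is resolved through custom_objects
    return custom_objects[config["loss"]]


original_inputs = [1.0, 2.0]
config = save_config(custom_loss(original_inputs))
print(config)  # {'loss': 'loss'}: the inner function's name, not 'custom_loss'

# In a later session the original inputs object is gone, so the factory
# can only be given some *other* object.
fresh_inputs = [1.0, 2.0]
restored = load_config(config, {"loss": custom_loss(fresh_inputs)})
captured = restored.__closure__[0].cell_contents
print(captured is fresh_inputs, captured is original_inputs)  # True False
```

If the naming assumption holds for your Keras version, the custom_objects key would have to match the stored name. In any case the closed-over tensor itself is never saved, which is why a freshly created Input cannot stand in for it.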

I don't want to just save the weights and rebuild the model with the same weights, because that would lose the optimizer state.

An alternative is to compute the loss inside a Keras layer and pass compile a dummy loss function that simply returns the model's output, which is already the loss. There are other ways to do this, but this is the one I prefer.

import tensorflow as tf
print('Tensorflow', tf.__version__)

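def custom_loss(tensor):
    # unpack [targets, predictions, model inputs] as passed to the Lambda layer below
    y_true, y_pred, inputs = tensor[0], tensor[1], tensor[2]
    loss = ...
    # placeholder so the example runs; replace with `return loss` once the loss is defined
    return tf.constant([0], dtype=tf.float32)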

def dummy_loss(y_true, y_pred):
    # the model's output is already the loss, so y_true is ignored
    return y_pred

def get_model(training=False):
    inputs = tf.keras.layers.Input(shape=(10,))
    y = tf.keras.layers.Activation('tanh')(inputs)
    if training:
        # training graph: targets are an extra input and the output is the loss itself
        targets = tf.keras.layers.Input(shape=(10,))
        loss_layer = tf.keras.layers.Lambda(custom_loss)([targets, y, inputs])
        model = tf.keras.models.Model(inputs=[inputs, targets], outputs=loss_layer)
    else:
        # inference graph: no targets input and no loss computation
        model = tf.keras.models.Model(inputs=inputs, outputs=y)
    return model


model = get_model(training=True)
model.compile(optimizer='sgd', loss=dummy_loss)
model.save('model.h5')

new_model = tf.keras.models.load_model('model.h5', custom_objects={'dummy_loss':dummy_loss})
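Why the dummy loss is enough can be checked in plain Python: because the training model's output is already the per-sample loss, dummy_loss only passes it through, and whatever is supplied as y_true never affects the result. This is a standard-library sketch, not TensorFlow code; training_forward and batch_loss are hypothetical stand-ins, and the inputs-dependent loss is an arbitrary example.

```python
import math

# Toy, stdlib-only model of the dummy-loss trick (no TensorFlow involved).
# training_forward is a hypothetical stand-in for the training model above:
# its "output" is already the per-sample loss.

def training_forward(inputs, targets):
    preds = [math.tanh(x) for x in inputs]  # same tanh "layer" as get_model
    # hypothetical custom loss that also depends on the inputs
    return [(1 + abs(x)) * (t - p) ** 2 for t, p, x in zip(targets, preds, inputs)]


def dummy_loss(y_true, y_pred):
    # y_pred is already the loss, so y_true is ignored
    return y_pred


def batch_loss(per_sample):
    # simplified reduction to a scalar: a mean over the batch
    return sum(per_sample) / len(per_sample)


per_sample = training_forward([0.0, 1.0], [0.0, 0.0])
# whatever is passed as y_true (here None) has no effect on the result
print(batch_loss(dummy_loss(None, per_sample)))
```

The design choice is that everything the loss needs, including the inputs, lives inside the saved graph, so loading only has to resolve the trivial dummy_loss by name.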

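This worked for me (note that with compile=False the loss and optimizer are not restored, so the model has to be recompiled before further training and the saved optimizer state is not loaded):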

load_model('mymodel.h5', custom_objects={'custom_loss': custom_loss(inputs)}, compile=False)
