
Keras: How to load a model with two outputs and a custom loss function?

I have trained a Keras (TensorFlow backend) model that has two outputs and a custom loss function. I need help loading the model from disk using the custom_objects argument.

When compiling the model, I used the loss and loss_weights arguments as follows:

losses = {
    'output_layer_1': custom_loss_fn,
    'output_layer_2': custom_loss_fn
}

loss_weights = {
    'output_layer_1': 1.0,
    'output_layer_2': 1.0
}

model.compile(loss=losses, loss_weights=loss_weights, optimizer=opt)

The model trains without any problems. I save it as follows:

model.save(model_path)

I haven't shown the definition of custom_loss_fn here because it is defined inside another custom Keras layer.
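Since the loss lives inside a custom layer, the loading side needs some way to get the same function back. One option, sketched here in plain Python under the assumption that Keras 2.x records a loss function in the saved file by its __name__, is to have the layer hand out its loss from a method and register the result under that name. CustomLossLayer, get_loss and scale are hypothetical names, and the class is a stand-in rather than a real Keras layer.

```python
class CustomLossLayer:
    # Hypothetical stand-in for the question's custom Keras layer; plain Python only.
    def __init__(self, scale=1.0):
        self.scale = scale

    def get_loss(self):
        scale = self.scale

        def custom_loss_fn(y_true, y_pred):
            # Scaled mean squared error over plain Python sequences.
            squared = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
            return scale * sum(squared) / len(squared)

        return custom_loss_fn


# Loading side: rebuild the layer with the same settings, then key
# custom_objects by the name the function is assumed to be saved under (its __name__).
loss_fn = CustomLossLayer(scale=1.0).get_loss()
custom_objects = {loss_fn.__name__: loss_fn}
```

Any state the closure captures (here scale) is not stored with the model under this assumption, so the layer has to be rebuilt with the same settings on the loading side.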

My question is how to load the persisted model at inference time. If it were a single-output model, I would load it using custom_objects as described in this Stack Overflow question: Loading model with custom loss + keras

model = keras.models.load_model(model_path, custom_objects={'custom_loss_fn':custom_loss_fn})

But how do I extend this to my case, where the model has two outputs, the losses and loss weights are defined in dictionaries, and the loss function is custom?

In other words, how should custom_objects be populated when losses and loss_weights are defined as dictionaries?

I'm using Keras v2.1.6 with Tensorflow backend v1.8.0.
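As far as I know, Keras 2.x's load_model reads the loss back from the training config stored in the file, where each loss function was written as its __name__, and replaces every string it finds in custom_objects with the matching object, walking into dicts and lists along the way. If that holds for 2.1.6, a dict of losses needs no special handling: the single entry {'custom_loss_fn': custom_loss_fn} covers both outputs, and loss_weights passes through unchanged because it only holds floats. This pure-Python sketch models that round trip; serialize_losses and resolve_custom_objects are hypothetical helpers, not Keras functions.

```python
def custom_loss_fn(y_true, y_pred):
    # Stand-in for the question's loss; only its name matters for this sketch.
    squared = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return sum(squared) / len(squared)


def serialize_losses(obj):
    # Simplified model of saving: callables are written out by their __name__.
    if isinstance(obj, dict):
        return {key: serialize_losses(value) for key, value in obj.items()}
    if isinstance(obj, list):
        return [serialize_losses(value) for value in obj]
    return obj.__name__ if callable(obj) else obj


def resolve_custom_objects(obj, custom_objects):
    # Simplified model of loading: recurse into dicts/lists and replace any
    # string value found in custom_objects with the registered object.
    if isinstance(obj, dict):
        return {key: resolve_custom_objects(value, custom_objects) for key, value in obj.items()}
    if isinstance(obj, list):
        return [resolve_custom_objects(value, custom_objects) for value in obj]
    if isinstance(obj, str) and obj in custom_objects:
        return custom_objects[obj]
    return obj


losses = {'output_layer_1': custom_loss_fn, 'output_layer_2': custom_loss_fn}
loss_weights = {'output_layer_1': 1.0, 'output_layer_2': 1.0}
custom_objects = {'custom_loss_fn': custom_loss_fn}

saved_losses = serialize_losses(losses)
restored_losses = resolve_custom_objects(saved_losses, custom_objects)
restored_weights = resolve_custom_objects(serialize_losses(loss_weights), custom_objects)
```

Only dict values are looked up; the keys 'output_layer_1' and 'output_layer_2' are output names and stay plain strings.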

If you can recompile the model on the loading side, the easiest way is to save just the weights with model.save_weights(). If you want to use save_model and have custom Keras layers, make sure they implement the get_config method (see this reference). As for the ops without gradient, I have seen this when mixing TensorFlow and Keras without properly using the keras.backend functions, but I can't help further without the model code itself.
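A minimal sketch of the weights-only route described above, written against tf.keras in TensorFlow 2.x (or Keras 3) rather than the Keras 2.1.6 / TensorFlow 1.8 setup in the question, so details may differ there. build_model, ScaleLayer, the layer sizes and the file name are hypothetical, and custom_loss_fn is a stand-in mean squared error. The same function builds and compiles the model on both sides, so no custom_objects are needed; ScaleLayer.get_config only shows what a custom layer would need if you later switch to saving the full model.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf
from tensorflow import keras


def custom_loss_fn(y_true, y_pred):
    # Stand-in for the question's custom loss: per-sample mean squared error.
    return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)


class ScaleLayer(keras.layers.Layer):
    # Hypothetical custom layer. get_config() is what load_model() would need
    # to rebuild it from a full-model save; save_weights() does not use it.
    def __init__(self, factor=1.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor

    def get_config(self):
        config = super().get_config()
        config.update({'factor': self.factor})
        return config


def build_model():
    # Same architecture and compile() call on the training and the loading side.
    inputs = keras.Input(shape=(4,), name='features')
    hidden = keras.layers.Dense(8, activation='relu', name='hidden')(inputs)
    hidden = ScaleLayer(factor=0.5, name='scale')(hidden)
    outputs = {
        'output_layer_1': keras.layers.Dense(1, name='output_layer_1')(hidden),
        'output_layer_2': keras.layers.Dense(1, name='output_layer_2')(hidden),
    }
    model = keras.Model(inputs, outputs)
    model.compile(
        optimizer='adam',
        loss={'output_layer_1': custom_loss_fn, 'output_layer_2': custom_loss_fn},
        loss_weights={'output_layer_1': 1.0, 'output_layer_2': 1.0},
    )
    return model


rng = np.random.default_rng(0)
x = rng.normal(size=(32, 4)).astype('float32')
y = {
    'output_layer_1': rng.normal(size=(32, 1)).astype('float32'),
    'output_layer_2': rng.normal(size=(32, 1)).astype('float32'),
}

# Training side: fit briefly, then persist only the weights.
trained = build_model()
trained.fit(x, y, epochs=1, verbose=0)
weights_path = os.path.join(tempfile.mkdtemp(), 'two_outputs.weights.h5')
trained.save_weights(weights_path)

# Loading side: rebuild and recompile (both inside build_model), then load the weights.
restored = build_model()
restored.load_weights(weights_path)

expected = trained.predict(x, verbose=0)
actual = restored.predict(x, verbose=0)
```

Keeping construction and compile() in one function means the loss and loss_weights dictionaries are written once, so they cannot drift between training and inference. Explicit layer names keep both builds identical, and the .weights.h5 suffix is the one Keras 3 requires for save_weights.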

