
How to save and load a Keras model with a custom loss function that depends on class variables?

I'm creating an LSTM-based variational autoencoder. I have written a custom loss function that uses two class variables, pos_weight and kl_cof. kl_cof changes its value every epoch.

    class LSTM_VAE(object):

        def __init__(.., pos_weight=None):
            ..
            self.kl_cof = K.variable(0.5)
            self.pos_weight = pos_weight
            self.autoencoder = None
            ....

            def vae_loss(x, x_decoded_mean):
                xent_loss = K.mean(x*-K.log(x_decoded_mean)*self.pos_weight+(1-x)*-K.log(1-x_decoded_mean))
                kl_loss = ....
                return xent_loss + K.get_value(self.kl_cof)*kl_loss
            ....
            self.autoencoder.compile(optimizer='Adam',
                                     loss=vae_loss,
                                     metrics=['accuracy'])

        def train_model(self, X_train, X_val, batch_size, epochs):
            cbk = LambdaCallback(....)
            print_weights = LambdaCallback(....)
            self.autoencoder.fit(x=X_train, y=X_train, batch_size=batch_size, epochs=epochs,
                                 callbacks=[self.checkpointer, cbk, print_weights],
                                 validation_data=(X_val, X_val))
        ....

When self.autoencoder has finished training, the model is saved as a 'lstm_vae.h5' file. But when I try to run

  lstm_vae=keras.models.load_model('lstm_vae.h5')

it throws a ValueError: Unknown loss function:vae_loss.

After some research in other threads, I found that I should pass the loss function through custom_objects, like this:

    lstm_vae=keras.models.load_model('lstm_vae.h5',custom_objects={'loss':vae_loss})

But once again it throws ValueError: Unknown loss function:vae_loss. The loss function is defined inside the LSTM_VAE class because it uses the class variables, so it is not available at the point where I call load_model.

So how do I correctly load the model when I am using this custom loss function?

Try this:

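from keras import backend as K
from keras.models import load_model
from keras.optimizers import Adam

# Load the architecture and weights only; skip restoring the saved loss and optimizer.
model = load_model("path/to/model.h5", compile=False)
optimizer = Adam(learning_rate=.01)

# Redefine the loss as it was at training time. As in the question, this
# must live where `self` (the LSTM_VAE instance) is in scope.
def vae_loss(x, x_decoded_mean):
    xent_loss = K.mean(x*-K.log(x_decoded_mean)*self.pos_weight+(1-x)*-K.log(1-x_decoded_mean))
    kl_loss = ....
    return xent_loss + K.get_value(self.kl_cof)*kl_loss

# Recompile with the custom loss so the model can be trained or evaluated again.
model.compile(optimizer=optimizer, loss=vae_loss, metrics=['accuracy'])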

I had the same issue. Loading the model uncompiled, then defining the loss and recompiling, was the solution that worked for me.
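If the loss needs pos_weight and kl_cof but you do not want to rebuild the whole class just to reload, one option is a small factory that returns the loss as a closure. The following is only a sketch under assumptions, not the original poster's code: `make_vae_loss`, `reload_with_loss` and `kl_loss_fn` are hypothetical names, the clipping epsilon is my own addition to keep the logarithms finite, and `kl_loss_fn` stands in for the KL term that the question elides. It also multiplies by the `kl_cof` variable directly instead of calling `K.get_value`, since `get_value` returns a plain number at the moment it is called and later per-epoch updates may then not reach the loss.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras


def make_vae_loss(pos_weight, kl_cof, kl_loss_fn):
    """Return a loss closure named `vae_loss` that captures its settings.

    kl_loss_fn is a hypothetical zero-argument callable that returns the
    KL term; the original post does not show how that term is computed.
    """

    def vae_loss(x, x_decoded_mean):
        # Clip predictions so log() stays finite (assumption of this sketch).
        p = tf.clip_by_value(x_decoded_mean, 1e-7, 1.0 - 1e-7)
        # Weighted cross-entropy: positive targets are scaled by pos_weight.
        xent_loss = tf.reduce_mean(
            -x * tf.math.log(p) * pos_weight - (1.0 - x) * tf.math.log(1.0 - p)
        )
        # Multiply by the variable itself rather than a fetched snapshot,
        # so changes made to kl_cof between epochs are picked up.
        return xent_loss + kl_cof * kl_loss_fn()

    return vae_loss


def reload_with_loss(path, loss_fn, learning_rate=0.01):
    """Load a saved model without its compile state, then recompile it."""
    model = keras.models.load_model(path, compile=False)
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss=loss_fn,
    )
    return model
```

With this in place, something like `reload_with_loss('lstm_vae.h5', make_vae_loss(pos_weight, kl_cof, kl_loss_fn))` would give back a compiled model, provided you can rebuild the KL term from the reloaded layers. Because the closure keeps the name vae_loss, passing it as `custom_objects={'vae_loss': loss_fn}` to load_model should also work, since the lookup is by the saved function name rather than by the key 'loss'.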


 