I'm creating an LSTM-based Variational Autoencoder. I have written my own custom loss function, which uses two class variables, pos_weight and kl_cof. The value of kl_cof changes every epoch.
from keras import backend as K
from keras.callbacks import LambdaCallback

class LSTM_VAE(object):
    def __init__(self, ..., pos_weight=None):
        ...
        self.kl_cof = K.variable(0.5)
        self.pos_weight = pos_weight
        self.autoencoder = None
        ...
        def vae_loss(x, x_decoded_mean):
            # weighted binary cross-entropy: positive targets are scaled by pos_weight
            xent_loss = K.mean(x*-K.log(x_decoded_mean)*self.pos_weight+(1-x)*-K.log(1-x_decoded_mean))
            kl_loss = ...
            # use the variable itself; K.get_value() would freeze its value at graph-build time
            return xent_loss + self.kl_cof*kl_loss
        ...
        self.autoencoder.compile(optimizer='Adam',
                                 loss=vae_loss,
                                 metrics=['accuracy'])

    def train_model(self, X_train, X_val, batch_size, epochs):
        cbk = LambdaCallback(...)
        print_weights = LambdaCallback(...)
        self.autoencoder.fit(x=X_train, y=X_train,
                             batch_size=batch_size, epochs=epochs,
                             callbacks=[self.checkpointer, cbk, print_weights],
                             validation_data=(X_val, X_val))
        ...
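Because kl_cof changes every epoch, the loss has to read the variable itself each time it is evaluated. K.get_value(self.kl_cof) instead returns the variable's current value as a plain number; since Keras builds the loss only once, when it sets up the training graph, that number (0.5 at that point) is baked into the loss and later updates never reach it. The sketch below is a pure-Python analogy, not Keras code; the Variable class and the two builder functions are hypothetical stand-ins.

```python
class Variable:
    """Hypothetical stand-in for K.variable: a mutable holder shared by reference."""

    def __init__(self, value):
        self.value = value


def build_loss_with_snapshot(kl_cof):
    # Like K.get_value(self.kl_cof): the number is copied once, at build time.
    coef = kl_cof.value

    def loss(xent_loss, kl_loss):
        return xent_loss + coef * kl_loss

    return loss


def build_loss_with_live_variable(kl_cof):
    # Like using self.kl_cof directly: the holder is read on every call.
    def loss(xent_loss, kl_loss):
        return xent_loss + kl_cof.value * kl_loss

    return loss


kl_cof = Variable(0.5)
frozen = build_loss_with_snapshot(kl_cof)
live = build_loss_with_live_variable(kl_cof)

kl_cof.value = 1.0  # what a per-epoch callback would do
print(frozen(1.0, 2.0))  # 2.0: still uses 0.5
print(live(1.0, 2.0))    # 3.0: picks up the update
```

In the Keras code this corresponds to multiplying by self.kl_cof directly and changing it from a callback with K.set_value(self.kl_cof, new_value).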
When self.autoencoder has finished training, the model is saved as 'lstm_vae.h5'. But when I try to run
lstm_vae = keras.models.load_model('lstm_vae.h5')
it throws ValueError: Unknown loss function:vae_loss.
After some research in other threads, I found that I should pass the loss function via custom_objects, like this:
lstm_vae = keras.models.load_model('lstm_vae.h5', custom_objects={'loss': vae_loss})
But it still throws ValueError: Unknown loss function:vae_loss. I defined the loss function inside the LSTM_VAE class because it uses LSTM_VAE class variables.
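The error names vae_loss because Keras saves a custom loss function by its Python name and, in load_model, resolves that name against its built-in losses plus whatever is passed in custom_objects. The key therefore has to be 'vae_loss', and the value has to be a function that is in scope where load_model is called; a function nested inside the LSTM_VAE class is not reachable by that name at module level. The sketch below is a simplified, hypothetical model of that lookup (resolve_loss and BUILTIN_LOSSES are illustrative names, not Keras internals).

```python
def mean_squared_error(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


# Hypothetical registry of built-in losses, keyed by name.
BUILTIN_LOSSES = {"mean_squared_error": mean_squared_error}


def resolve_loss(name, custom_objects=None):
    """Simplified model of how a saved loss name is turned back into a function."""
    lookup = dict(BUILTIN_LOSSES)
    lookup.update(custom_objects or {})
    if name not in lookup:
        raise ValueError("Unknown loss function:" + name)
    return lookup[name]


def vae_loss(x, x_decoded_mean):
    # Placeholder body; only the function's name matters for the lookup.
    return 0.0


saved_name = vae_loss.__name__  # the name stored with the model: "vae_loss"

try:
    resolve_loss(saved_name, custom_objects={"loss": vae_loss})
except ValueError as err:
    print(err)  # Unknown loss function:vae_loss

print(resolve_loss(saved_name, custom_objects={"vae_loss": vae_loss}) is vae_loss)  # True
```

With the real API, the matching call would be keras.models.load_model('lstm_vae.h5', custom_objects={'vae_loss': vae_loss}), with vae_loss defined before the call.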
So how do I correctly load the model when I am using this custom loss function?
Try this:
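from keras import backend as K
from keras.models import load_model
from keras.optimizers import Adam

# compile=False restores architecture and weights but skips the saved loss, so vae_loss is never looked up
model = load_model("path/to/model.h5", compile=False)
optimizer = Adam(learning_rate=.01)
# self must be an LSTM_VAE instance here, e.g. run this inside one of its methods
def vae_loss(x, x_decoded_mean):
    xent_loss = K.mean(x*-K.log(x_decoded_mean)*self.pos_weight+(1-x)*-K.log(1-x_decoded_mean))
    kl_loss = ...
    return xent_loss + self.kl_cof*kl_loss  # the variable itself, so per-epoch updates apply

model.compile(optimizer=optimizer, loss=vae_loss, metrics=['accuracy'])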
I had the same issue, and loading the model uncompiled, then defining the loss and recompiling, was the solution that worked for me.
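The answer's vae_loss still refers to self.pos_weight and self.kl_cof, so it only works where an LSTM_VAE instance is in scope. One way to rebuild the loss anywhere (a variation, not part of the original answer) is a factory that closes over the former class variables and returns a function still named vae_loss, usable both for model.compile(loss=...) and as custom_objects={'vae_loss': ...}. The sketch below uses plain Python lists instead of Keras tensors; make_vae_loss, Coefficient and kl_loss_fn are hypothetical names, and kl_loss_fn stands in for the KL term that the post elides.

```python
import math


class Coefficient:
    """Hypothetical stand-in for K.variable: read on every call, updatable between epochs."""

    def __init__(self, value):
        self.value = value


def make_vae_loss(pos_weight, kl_cof, kl_loss_fn):
    """Return a loss named vae_loss that no longer needs `self`."""

    def vae_loss(x, x_decoded_mean):
        # Positively weighted binary cross-entropy, averaged over elements,
        # mirroring the xent_loss expression in the question.
        terms = [
            xi * -math.log(pi) * pos_weight + (1 - xi) * -math.log(1 - pi)
            for xi, pi in zip(x, x_decoded_mean)
        ]
        xent_loss = sum(terms) / len(terms)
        return xent_loss + kl_cof.value * kl_loss_fn()

    return vae_loss


kl_cof = Coefficient(0.5)
loss = make_vae_loss(pos_weight=2.0, kl_cof=kl_cof, kl_loss_fn=lambda: 1.0)
print(loss.__name__)                       # vae_loss
print(round(loss([1, 0], [0.5, 0.5]), 4))  # 1.5397
```

A Keras version would build the same expression with K.mean and K.log on tensors, as in the question, and pass the returned function to compile after load_model(..., compile=False).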