I'm having trouble loading model weights in (tf.)Keras.
My model is a simple LSTM with a pre-trained word embedding, and I left the embedding layer trainable during training.
I saved the model weights with the following checkpoint callback:
from tensorflow.keras.callbacks import ModelCheckpoint
mc = ModelCheckpoint(filepath, save_weights_only=True, monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
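For context, a callback like this only does something once it is passed to `fit`, and `val_accuracy` only exists if `accuracy` is a compiled metric and validation data is supplied. A minimal sketch, assuming a hypothetical `build_model()`, made-up sizes, and random toy data (none of these come from the question):

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical sizes for a toy LSTM classifier.
VOCAB_SIZE, EMBED_DIM, MAXLEN = 50, 8, 10


def build_model():
    """Hypothetical stand-in for the question's build_model()."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(MAXLEN,), dtype="int32"),
        tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM, trainable=True),
        tf.keras.layers.LSTM(4),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])


model = build_model()
# "accuracy" must be a compiled metric, and validation data must be given,
# for "val_accuracy" to exist as a monitorable quantity.
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# Random toy data standing in for real tokenized sequences and labels.
rng = np.random.default_rng(0)
x = rng.integers(0, VOCAB_SIZE, size=(64, MAXLEN)).astype("int32")
y = rng.integers(0, 2, size=(64, 1)).astype("float32")

# Keras 3 requires weights-only checkpoint paths to end in ".weights.h5".
filepath = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
mc = tf.keras.callbacks.ModelCheckpoint(
    filepath, save_weights_only=True, monitor="val_accuracy",
    mode="max", verbose=1, save_best_only=True)

history = model.fit(x, y, validation_split=0.25, epochs=2,
                    verbose=0, callbacks=[mc])
```

With `save_best_only=True`, the file is overwritten only when `val_accuracy` improves on the best value seen so far.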
I confirmed that the HDF5 file exists at filepath and is about 18 MB.
Later, I tried to load the weights with the following code:
model = build_model()  # the same function I used to build the model during training
model = model.load_weights(filepath)
However, model.load_weights(filepath) returns None.
Question 1: Is there a problem with this code? If not, could this be because I left the word embedding trainable?
Question 2: In that case, where is the fine-tuned word embedding saved? Is it stored with the other parameters in the HDF5 file? If so, how can I load it?
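Regarding both questions: `load_weights` restores the saved values into the model it is called on, and for an HDF5 file it returns `None`, so `model = model.load_weights(filepath)` replaces the model with `None`. The trainable embedding is an ordinary weight of the model, so a weights-only checkpoint stores it alongside the LSTM and dense weights. A minimal sketch, assuming a hypothetical `build_model()` with made-up sizes and an explicitly named `"embedding"` layer:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

VOCAB_SIZE, EMBED_DIM, MAXLEN = 50, 8, 10  # hypothetical sizes


def build_model():
    """Hypothetical stand-in for the question's build_model()."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(MAXLEN,), dtype="int32"),
        # Named explicitly so it can be looked up reliably later.
        tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM, trainable=True,
                                  name="embedding"),
        tf.keras.layers.LSTM(4),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])


# Stand-in for the trained model; ModelCheckpoint(save_weights_only=True)
# writes its file by calling save_weights in the same way.
trained = build_model()
filepath = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
trained.save_weights(filepath)

model = build_model()
model.load_weights(filepath)  # restores weights in place; don't reassign `model`

# The trainable (fine-tuned) embedding is restored along with everything else.
restored_embedding = model.get_layer("embedding").get_weights()[0]
```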
To extract the word embedding, first get the embedding layer from the model:
embed_layer = model.get_layer('embedding_26')  # 'embedding_26' is the auto-generated name of the embedding layer
Then extract the trained word embeddings:
embed_layer.get_weights()
>>> [array([[ 9.0566e-01, -7.1792e-01, -1.9574e-01, ..., 1.1230e-03,
2.8188e-02, 3.0385e-01],
[ 5.8560e-01, -3.6964e-01, 6.3480e-02, ..., 5.6656e-01,
-3.6404e-01, -2.5202e-01],
[ 4.5269e-01, -6.2509e-01, 1.6866e-01, ..., -5.0146e-01,
2.9764e-01, 1.4548e-01],
...,
[-1.0632e-01, 6.8057e-01, -1.5388e+00, ..., -4.8493e-01,
3.2478e-01, -1.1330e-01],
[ 7.6822e-01, 7.1786e-01, 5.8778e-01, ..., 1.6097e-01,
8.9411e-02, 8.4237e-01],
[ 0.0000e+00, 0.0000e+00, 0.0000e+00, ..., 0.0000e+00,
0.0000e+00, 0.0000e+00]], dtype=float32)]
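Auto-generated names such as `embedding_26` depend on how many layers of that type were created earlier in the Python session, so a name lookup can break after rebuilding the model. A sketch of a type-based lookup, assuming a toy model initialized from a made-up "pre-trained" matrix: `get_weights()` returns a one-element list, and its first entry is the `(vocab_size, embedding_dim)` matrix whose row `i` is the vector for token index `i`.

```python
import numpy as np
import tensorflow as tf

VOCAB_SIZE, EMBED_DIM, MAXLEN = 50, 8, 10  # hypothetical sizes
# Made-up "pre-trained" matrix; in practice it would come from e.g. GloVe.
pretrained = np.random.default_rng(0).normal(
    size=(VOCAB_SIZE, EMBED_DIM)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(MAXLEN,), dtype="int32"),
    tf.keras.layers.Embedding(
        VOCAB_SIZE, EMBED_DIM,
        embeddings_initializer=tf.keras.initializers.Constant(pretrained),
        trainable=True),
    tf.keras.layers.LSTM(4),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# Look the layer up by type rather than by an auto-generated name.
embed_layer = next(layer for layer in model.layers
                   if isinstance(layer, tf.keras.layers.Embedding))

# One row per token index; row 3 is the vector for token index 3.
embedding_matrix = embed_layer.get_weights()[0]
vector_for_token_3 = embedding_matrix[3]
```

After training (or after `load_weights`), the same lookup returns the fine-tuned values instead of the initial ones.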
I am not sure if you can directly load weights from a file, but here is something that you could do:
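from tensorflow.keras.models import load_model
model = load_model('best_model.h5')  # a file saved with model.save(), i.e. the full model
weights = model.get_weights()  # list of NumPy arrays, one per weight variable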
You can then load these weights into another model with the same architecture:
model2.set_weights(weights)
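Note that `load_model` needs a file holding the whole model (architecture and weights), such as one written by `model.save()`; a checkpoint saved with `save_weights_only=True`, as in the question, contains only weights and is meant for `load_weights` on a freshly built model. A self-contained sketch of this round trip, assuming a hypothetical `build_model()` and using the native `.keras` format (an `.h5` path uses the older HDF5 format and should behave similarly):

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

VOCAB_SIZE, EMBED_DIM, MAXLEN = 50, 8, 10  # hypothetical sizes


def build_model():
    """Hypothetical stand-in for the question's build_model()."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(MAXLEN,), dtype="int32"),
        tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM, trainable=True),
        tf.keras.layers.LSTM(4),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])


original = build_model()
path = os.path.join(tempfile.mkdtemp(), "best_model.keras")
original.save(path)  # full model: architecture + weights

model = tf.keras.models.load_model(path)
weights = model.get_weights()  # list of NumPy arrays

# A second model with the same architecture can take the same weights.
model2 = build_model()
model2.set_weights(weights)
```

`set_weights` matches arrays by position, so it only works when both models create their weights in the same order and with the same shapes.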