
(tf.)Keras: loading saved model weights with trainable word embeddings

I'm having trouble loading model weights in (tf.)Keras.

My model is a simple LSTM model with a pre-trained word embedding, and I left the word embedding trainable during training.

I saved the model weights with the following code:

from tensorflow.keras.callbacks import ModelCheckpoint
mc = ModelCheckpoint(filepath, save_weights_only=True, monitor='val_accuracy', mode='max', verbose=1, save_best_only=True)
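For context, here is a minimal sketch of how such a checkpoint might be wired into training. It assumes a hypothetical build_model() whose trainable Embedding layer (named word_embedding here) is filled with a stand-in "pre-trained" matrix of random numbers, plus random toy data. With save_weights_only=True the callback writes the weights of every layer, the trainable embedding included, but not the model architecture.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

VOCAB_SIZE, EMBED_DIM, SEQ_LEN = 50, 8, 10  # hypothetical toy sizes


def build_model():
    """Hypothetical stand-in for the question's build_model()."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(SEQ_LEN,), dtype="int32"),
        tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM, trainable=True,
                                  name="word_embedding"),
        tf.keras.layers.LSTM(4),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    # Stand-in for a pre-trained embedding matrix (random values, for illustration)
    pretrained = np.random.rand(VOCAB_SIZE, EMBED_DIM).astype("float32")
    model.get_layer("word_embedding").set_weights([pretrained])
    model.compile(optimizer="adam", loss="binary_crossentropy",
                  metrics=["accuracy"])
    return model


# Random toy data standing in for real token-id sequences and labels
x = np.random.randint(0, VOCAB_SIZE, size=(64, SEQ_LEN)).astype("int32")
y = np.random.randint(0, 2, size=(64, 1)).astype("float32")

# The name ends in ".h5", so the weights are written in HDF5 format
filepath = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
mc = tf.keras.callbacks.ModelCheckpoint(
    filepath, save_weights_only=True, monitor="val_accuracy",
    mode="max", verbose=1, save_best_only=True)

model = build_model()
# In "max" mode the first epoch always counts as an improvement,
# so the file is written at least once
model.fit(x, y, validation_split=0.25, epochs=2, callbacks=[mc], verbose=0)
```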

I checked that the HDF5 file exists at filepath; it is about 18 MB.

Later, I tried to load the weights with the following code:

model = build_model()  # the same function I used to build the model during training
model = model.load_weights(filepath)

However, model.load_weights(filepath) returns None.
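The None by itself is expected: load_weights() updates the model's variables in place and, for an HDF5 weights file, returns None, so assigning its return value back to model discards the model. A minimal sketch of the in-place pattern, assuming a hypothetical build_model() that recreates the same architecture and a temporary file path:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

VOCAB_SIZE, EMBED_DIM = 50, 8  # hypothetical toy sizes


def build_model():
    """Hypothetical builder; returns the same architecture on every call."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(None,), dtype="int32"),
        tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM, trainable=True),
        tf.keras.layers.LSTM(4),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])


filepath = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
trained = build_model()
trained.save_weights(filepath)  # same kind of file as ModelCheckpoint(save_weights_only=True)

model = build_model()         # fresh, randomly initialised copy
model.load_weights(filepath)  # loads in place; do not write model = model.load_weights(...)

# Every weight array, the embedding matrix included, now matches the saved model
restored = all(np.array_equal(a, b)
               for a, b in zip(trained.get_weights(), model.get_weights()))
print(restored)  # True
```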

Question 1: Is there a problem with this code? If not, could this be because I left the word embedding trainable?

Question 2: In that case, where is the fine-tuned word embedding saved? Is it stored with the other parameters in the HDF5 file? If so, how can I load it?
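One way to check whether the fine-tuned embedding is in the weights file is to list the datasets inside it with h5py. This is a sketch under assumptions: the toy model, the sizes, and the layer name word_embedding are hypothetical, and because the internal group layout of the file differs between Keras versions, the embedding is located by its (vocab_size, embedding_dim) shape rather than by a fixed path.

```python
import os
import tempfile

import h5py
import numpy as np
import tensorflow as tf

VOCAB_SIZE, EMBED_DIM = 50, 8  # hypothetical toy sizes

model = tf.keras.Sequential([
    tf.keras.Input(shape=(None,), dtype="int32"),
    tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM, trainable=True,
                              name="word_embedding"),
    tf.keras.layers.LSTM(4),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
filepath = os.path.join(tempfile.mkdtemp(), "best.weights.h5")
model.save_weights(filepath)

datasets = {}


def collect(name, obj):
    # Called for every group and dataset in the file; keep only datasets (weight arrays)
    if isinstance(obj, h5py.Dataset):
        datasets[name] = obj[()]


with h5py.File(filepath, "r") as f:
    f.visititems(collect)

for name in sorted(datasets):
    print(name, datasets[name].shape)

# Only the embedding matrix has shape (VOCAB_SIZE, EMBED_DIM) in this toy model
embedding_paths = [name for name, arr in datasets.items()
                   if arr.shape == (VOCAB_SIZE, EMBED_DIM)]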

To extract the word embedding, first get the embedding layer from the model:

embed_layer = model.get_layer('embedding_26')  # 'embedding_26' is the auto-generated layer name; model.summary() lists the names in your model

Then extract the trained word embeddings:

embed_layer.get_weights()

>>> [array([[ 9.0566e-01, -7.1792e-01, -1.9574e-01, ...,  1.1230e-03,
          2.8188e-02,  3.0385e-01],
        [ 5.8560e-01, -3.6964e-01,  6.3480e-02, ...,  5.6656e-01,
         -3.6404e-01, -2.5202e-01],
        [ 4.5269e-01, -6.2509e-01,  1.6866e-01, ..., -5.0146e-01,
          2.9764e-01,  1.4548e-01],
        ...,
        [-1.0632e-01,  6.8057e-01, -1.5388e+00, ..., -4.8493e-01,
          3.2478e-01, -1.1330e-01],
        [ 7.6822e-01,  7.1786e-01,  5.8778e-01, ...,  1.6097e-01,
          8.9411e-02,  8.4237e-01],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]], dtype=float32)]
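Naming the embedding layer yourself makes it easier to retrieve than relying on a generated name such as embedding_26. A minimal sketch, where the layer name word_embedding, the toy sizes, and the word_index dictionary are all hypothetical:

```python
import numpy as np
import tensorflow as tf

VOCAB_SIZE, EMBED_DIM = 50, 8      # hypothetical toy sizes
word_index = {"cat": 1, "dog": 2}  # hypothetical tokenizer vocabulary (word -> token id)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(None,), dtype="int32"),
    # An explicit name avoids hunting for auto-generated names like 'embedding_26'
    tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM, name="word_embedding"),
    tf.keras.layers.LSTM(4),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# get_weights() returns a one-element list holding the (VOCAB_SIZE, EMBED_DIM) matrix
embedding_matrix = model.get_layer("word_embedding").get_weights()[0]
cat_vector = embedding_matrix[word_index["cat"]]  # one row per token id
print(embedding_matrix.shape, cat_vector.shape)   # (50, 8) (8,)
```

Row i of the matrix is the vector for token id i, so the tokenizer's word-to-index mapping selects the fine-tuned vector for a given word.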

I'm not sure whether you can load the weights directly from a file, but here is something you could do:

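from tensorflow.keras.models import load_model

model = load_model('best_model.h5')  # needs a full-model save, not a save_weights_only=True file
weights = model.get_weights()  # list of NumPy arrays holding all of the model's weights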

You can then load these weights into another model with the same architecture:

model2.set_weights(weights)
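A minimal sketch of the get_weights()/set_weights() round trip, assuming a hypothetical build_model() shared by both models. set_weights() expects the arrays in the same order and with the same shapes as model2's own weights, which is why the architectures must match.

```python
import numpy as np
import tensorflow as tf

VOCAB_SIZE, EMBED_DIM = 50, 8  # hypothetical toy sizes


def build_model():
    """Hypothetical builder; both models must share this exact architecture."""
    return tf.keras.Sequential([
        tf.keras.Input(shape=(None,), dtype="int32"),
        tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM),
        tf.keras.layers.LSTM(4),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])


model, model2 = build_model(), build_model()

weights = model.get_weights()  # NumPy arrays, in layer order
model2.set_weights(weights)    # copies them, embedding matrix included

same = all(np.array_equal(a, b) for a, b in zip(weights, model2.get_weights()))
print(len(weights), same)  # 6 True
```

The six arrays are the embedding matrix, the LSTM kernel, recurrent kernel and bias, and the Dense kernel and bias.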
