Save Keras model weights directly to bytes/memory?

Keras allows saving either entire models or just the model weights (see thread ). When saving the weights, they must be written to a file, e.g.:

model = keras_model()
model.save_weights('/tmp/model.h5')

Instead of writing to a file, I'd like to save the bytes directly in memory, with something like:

model.dump_weights()

TensorFlow doesn't seem to have this, so as a workaround I'm writing to disk and then reading the file back into memory:

temp = '/tmp/weights.h5'
model.save_weights(temp)
with open(temp, 'rb') as f:
    weightbytes = f.read()

Is there any way to avoid this roundabout approach?

weights = model.get_weights() returns the model weights as a list of NumPy arrays, and model.set_weights(weights) sets them. One issue, though, is WHEN to save the weights. Generally you want the weights from the epoch with the lowest validation loss. The Keras ModelCheckpoint callback (with save_best_only=True) will save the weights with the lowest validation loss to a file.

I found saving to a file inconvenient, so I wrote a small custom callback that keeps the weights with the lowest validation loss in a class variable; after training is complete, I load those weights into the model to make predictions. The code is below. Just add save_best_weights() to the list of callbacks you pass to model.fit (callbacks are not passed to compile).

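import numpy as np
import tensorflow as tf

class save_best_weights(tf.keras.callbacks.Callback):
    # Class attribute; requires `model` to already exist when this class is defined
    best_weights = model.get_weights()

    def __init__(self):
        super(save_best_weights, self).__init__()
        self.best = np.inf  # lowest validation loss seen so far

    def on_epoch_end(self, epoch, logs=None):
        current_loss = logs.get('val_loss')
        # Requires the model to be compiled with metrics=['accuracy']
        accuracy = logs.get('val_accuracy') * 100
        if np.less(current_loss, self.best):
            self.best = current_loss
            # Keep a copy of the best weights so far in the class attribute
            save_best_weights.best_weights = self.model.get_weights()
            print('\nSaving weights validation loss= {0:6.4f}  validation accuracy= {1:6.3f} %\n'.format(current_loss, accuracy))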
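Since get_weights() returns a plain list of NumPy arrays, that list can also be turned into bytes entirely in memory, which is what the question asks for. A minimal sketch, assuming NumPy only; the helper names weights_to_bytes and bytes_to_weights are hypothetical, and the two small arrays stand in for the output of model.get_weights():

```python
import io

import numpy as np


def weights_to_bytes(weights):
    """Pack a list of arrays (e.g. from model.get_weights()) into .npz bytes."""
    buffer = io.BytesIO()
    # np.savez names positional arrays arr_0, arr_1, ... in order
    np.savez(buffer, *weights)
    return buffer.getvalue()


def bytes_to_weights(data):
    """Unpack .npz bytes back into a list of arrays, in the original order."""
    with np.load(io.BytesIO(data)) as archive:
        return [archive[f"arr_{i}"] for i in range(len(archive.files))]


# Stand-in for model.get_weights(): a kernel and a bias
weights = [np.ones((3, 2), dtype=np.float32), np.zeros(2, dtype=np.float32)]
weight_bytes = weights_to_bytes(weights)
restored = bytes_to_weights(weight_bytes)
# model.set_weights(restored) would then load them into a model of the same architecture
```

Unlike pickle, the .npz format only stores numeric arrays, so np.load can read it with its default allow_pickle=False.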

Convert the model to JSON and serialize it with dill, then keep the resulting bytes; you can base64-encode them if you need to store them in a database. Save the model weights the same way. Everything happens in memory, without touching the disk.

from io import BytesIO
import base64
import dill

# Saving the model as base64
model_json = Keras_model.to_json()

def Base64Converter(ObjectFile):
    # Serialize the object into an in-memory buffer, then base64-encode the bytes
    bytes_container = BytesIO()
    dill.dump(ObjectFile, bytes_container)
    bytes_container.seek(0)
    bytes_file = bytes_container.read()
    base64File = base64.b64encode(bytes_file)
    return base64File

base64KModelJson = Base64Converter(model_json)
base64KModelJsonWeights = Base64Converter(Keras_model.get_weights())

To load it back, decode the base64, deserialize with dill from an in-memory buffer (a tempfile.TemporaryFile would touch the disk), and rebuild the model with model_from_json:

# Loading back
from keras.models import model_from_json

def ObjectConverter(base64_File):
    # Decode the base64 text and deserialize straight from memory
    loaded_binary = base64.b64decode(base64_File)
    ObjectFile = dill.load(BytesIO(loaded_binary))
    return ObjectFile

modeljson = ObjectConverter(base64KModelJson)
modelweights = ObjectConverter(base64KModelJsonWeights)
loaded_model = model_from_json(modeljson)
loaded_model.set_weights(modelweights)
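For the "store to database" part of this answer, a minimal sketch using an in-memory SQLite database. It swaps dill for the standard pickle module (enough for a list of NumPy arrays), and the table name models and the arrays standing in for Keras_model.get_weights() are hypothetical. As with any pickle data, only unpickle values you stored yourself.

```python
import base64
import pickle
import sqlite3

import numpy as np

# Stand-in for Keras_model.get_weights()
weights = [np.arange(6, dtype=np.float32).reshape(3, 2), np.zeros(2, dtype=np.float32)]

# Serialize in memory, then base64-encode so the value fits in a TEXT column
encoded = base64.b64encode(pickle.dumps(weights)).decode("ascii")

# In-memory SQLite database; the "models" table is hypothetical
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE models (name TEXT PRIMARY KEY, weights_b64 TEXT)")
conn.execute("INSERT INTO models VALUES (?, ?)", ("demo", encoded))

# Read the text back and reverse both steps
(stored,) = conn.execute(
    "SELECT weights_b64 FROM models WHERE name = ?", ("demo",)
).fetchone()
restored = pickle.loads(base64.b64decode(stored))
conn.close()
```

SQLite can also store the raw pickle bytes in a BLOB column, in which case the base64 step is unnecessary.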

If you look at the Keras model-saving code in the TensorFlow source (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/engine/training.py), the relevant lines for saving in h5 format are:

if save_format == 'h5':
    with h5py.File(filepath, 'w') as f:
        hdf5_format.save_weights_to_hdf5_group(f, self.layers)

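Therefore, you can create an in-memory h5 file directly and save the model weights into it (note that hdf5_format lives in a private TensorFlow module, so its import path can change between versions):

import io
import h5py
from tensorflow.python.keras.saving import hdf5_format

bytes_file = io.BytesIO()
with h5py.File(bytes_file, 'w') as f:
    # `model` is the Keras model whose weights you want to save
    hdf5_format.save_weights_to_hdf5_group(f, model.layers)
weight_bytes = bytes_file.getvalue()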
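What makes this work is that h5py.File accepts a Python file-like object such as io.BytesIO (supported since h5py 2.9). A minimal round-trip sketch using only h5py and NumPy, so it runs without the private Keras helper; the dataset names weight_0, weight_1, ... are a hypothetical layout, not the one hdf5_format writes:

```python
import io

import h5py
import numpy as np

# Stand-in for model.get_weights()
weights = [np.ones((3, 2), dtype=np.float32), np.zeros(2, dtype=np.float32)]

# Write every array into an HDF5 file that lives only in memory
buffer = io.BytesIO()
with h5py.File(buffer, "w") as f:
    for i, w in enumerate(weights):
        f.create_dataset(f"weight_{i}", data=w)
weight_bytes = buffer.getvalue()

# Read the bytes back through a fresh BytesIO; [()] reads a whole dataset
with h5py.File(io.BytesIO(weight_bytes), "r") as f:
    restored = [f[f"weight_{i}"][()] for i in range(len(f))]
```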

Thanks @ddoGas for pointing out the model.get_weights() method, which returns a list of weights that can then be serialized. Some context on why I am not saving the model the conventional way: we are working with model wrapper classes that combine a model with custom behavior. For example, special validation is needed before prediction:

class CNN:
    ...
    def predict(self, x):
        # x: the input batch passed through to the wrapped Keras model
        self.do_special_validation()
        return self.model.predict(x)

Hence, we're serializing the CNN class, not just the underlying model. This is our solution for pickling the entire object (pickle.dumps(CNN()) fails; otherwise we'd just use that):

import pickle

def serialize(cnn):
    return pickle.dumps({
        "weights": cnn.model.get_weights(),
        "cnnclass": cnn.__class__
    })

def deserialize(cnn_bytes):
    loaded = pickle.loads(cnn_bytes)
    weights, cnnclass = loaded['weights'], loaded['cnnclass']
    cnninstance = cnnclass()
    cnninstance.model.set_weights(weights)
    return cnninstance

Works well, thanks!

PS: I use cnn.__class__ because I don't want to bind this to the CNN class specifically; it should work for any class that has a model attribute and can be constructed without arguments (deserialize calls cnnclass() to rebuild the model before setting the weights).
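The pattern only relies on the wrapped object having a model attribute with get_weights()/set_weights() and on the wrapper class being constructible with no arguments. A minimal, self-contained sketch of the same serialize/deserialize pair, using a hypothetical NumPy-backed StubModel in place of a Keras model so it runs without TensorFlow:

```python
import pickle

import numpy as np


class StubModel:
    """Hypothetical stand-in exposing Keras-style get_weights/set_weights."""

    def __init__(self):
        self._weights = [np.zeros((3, 2)), np.zeros(2)]

    def get_weights(self):
        return [w.copy() for w in self._weights]

    def set_weights(self, weights):
        self._weights = [np.array(w) for w in weights]


class CNN:
    """Wrapper that builds its own model in __init__ (no constructor args)."""

    def __init__(self):
        self.model = StubModel()


def serialize(cnn):
    # Store the weights plus a reference to the wrapper class, not the instance
    return pickle.dumps({"weights": cnn.model.get_weights(), "cnnclass": cnn.__class__})


def deserialize(cnn_bytes):
    loaded = pickle.loads(cnn_bytes)
    cnninstance = loaded["cnnclass"]()  # builds a fresh model
    cnninstance.model.set_weights(loaded["weights"])
    return cnninstance


original = CNN()
original.model.set_weights([np.ones((3, 2)), np.full(2, 0.5)])
restored = deserialize(serialize(original))
```

pickle stores the class by its module and name, so the wrapper class must be importable under the same name wherever the bytes are loaded.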

I wanted to use the code from Gerry P's answer in my own module, which didn't work as-is, so I made a few changes. Here is what I did:

  • Moved the code into a file/module named topmodelbox.py
  • Added the required imports
  • Initialized best_weights with an empty list, because there is no (simple) access to the model at that point and nothing more is necessary anyway (in my case)
  • Removed the accuracy part, as accuracy is not available for my (and many other) loss functions
  • Added some information on how to use the class (see below)

Contents of topmodelbox.py

import numpy as np
import tensorflow as tf

class cb_hold_best_weights(tf.keras.callbacks.Callback):
    # Class attribute: shared by all instances, filled in during training
    best_weights = []
    def __init__(self):
        super(cb_hold_best_weights, self).__init__()
        self.best = np.inf  # lowest validation loss so far (np.Inf was removed in NumPy 2.0)
    def on_epoch_end(self, epoch, logs=None):
        current_loss = logs.get('val_loss')
        if np.less(current_loss, self.best):
            self.best = current_loss
            # Keep a copy of the best weights in memory
            cb_hold_best_weights.best_weights = self.model.get_weights()
            print('\nSaving weights validation loss= {0:6.4f}\n'.format(current_loss))

After import topmodelbox, the callback can simply be used by adding it to the list of callbacks, like so:

callbacks=[topmodelbox.cb_hold_best_weights()]

This goes in a call such as model.fit, for example.

Later we can use

model.set_weights(topmodelbox.cb_hold_best_weights.best_weights) 

to load the stored weights.
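Because best_weights is a class attribute, all instances of cb_hold_best_weights share it, so two trainings in the same process overwrite each other's result. A minimal sketch of a per-instance variant that also restores the weights itself in on_train_end, assuming TensorFlow 2.x with tf.keras; the class name HoldBestWeights and the tiny random regression problem are hypothetical:

```python
import numpy as np
import tensorflow as tf


class HoldBestWeights(tf.keras.callbacks.Callback):
    """Keep the lowest-val_loss weights in memory and restore them at the end."""

    def __init__(self, monitor="val_loss"):
        super().__init__()
        self.monitor = monitor
        self.best = np.inf
        self.best_weights = None  # per instance, not shared between callbacks

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is not None and current < self.best:
            self.best = current
            self.best_weights = self.model.get_weights()

    def on_train_end(self, logs=None):
        # Load the best weights back into the model once training finishes
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)


# Tiny hypothetical regression problem, just to exercise the callback
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 3)).astype("float32")
y = x @ np.array([[1.0], [-2.0], [0.5]], dtype="float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(3,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

hold = HoldBestWeights()
model.fit(x[:48], y[:48], validation_data=(x[48:], y[48:]),
          epochs=3, verbose=0, callbacks=[hold])
```

After fit returns, the model already holds the best weights, and hold.best_weights can still be serialized separately if needed.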

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
© 2020-2024 STACKOOM.COM