
Problem converting Keras model to Tensorflow-Lite using TocoConverter.from_keras_model_file

I'm facing an issue using TOCO to convert a Keras model to TfLite.

I followed the guide at: https://www.tensorflow.org/api_docs/python/tf/contrib/lite/TocoConverter

How I use TOCO:

import os
import tensorflow as tf


def create_lite_model(keras_model_file):
    lite_model_name = 'lite_model_file.tflite'
    # WEIGHTS_DIRECTORY is a path constant defined elsewhere in the script
    tf_lite_graph = os.path.join(WEIGHTS_DIRECTORY, lite_model_name)
    converter = tf.contrib.lite.TocoConverter.from_keras_model_file(keras_model_file)
    tf_lite_model = converter.convert()
    with open(tf_lite_graph, "wb") as f:
        f.write(tf_lite_model)

Getting the following error:

File "/tensorflow/contrib/lite/python/lite.py", line 356, in from_keras_model_file
keras_model = _keras.models.load_model(model_file)
File "/tensorflow/python/keras/engine/saving.py", line 251, in load_model
training_config['weighted_metrics'])
KeyError: 'weighted_metrics'

Does anybody have a solution for this problem?

So far I haven't found a solution, but I'm using a workaround.

The workaround converts the Keras model to a tf Graph, stores that graph with a SavedModelBuilder, and finally calls TocoConverter.from_saved_model(...).

import os

import tensorflow as tf
from keras import backend as K
from keras.models import load_model

K.set_learning_phase(False)


def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
    graph = session.graph
    with graph.as_default():
        freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))
        output_names = output_names or []
        output_names += [v.op.name for v in tf.global_variables()]
        input_graph_def = graph.as_graph_def()
        if clear_devices:
            for node in input_graph_def.node:
                node.device = ""
        frozen_graph = tf.graph_util.convert_variables_to_constants(session, input_graph_def, output_names,
                                                                      freeze_var_names)
        return frozen_graph


def create_lite_model_from_saved_model(saved_model_dir, tf_lite_path):
    converter = tf.contrib.lite.TocoConverter.from_saved_model(saved_model_dir)
    tf_lite_model = converter.convert()
    with open(tf_lite_path, "wb") as f:
        f.write(tf_lite_model)


def save_model(keras_model, session, pb_model_path):
    x = keras_model.input
    y = keras_model.output
    prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"inputs": x}, {"prediction": y})
    builder = tf.saved_model.builder.SavedModelBuilder(pb_model_path)
    legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
    signature = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: prediction_signature, }
    builder.add_meta_graph_and_variables(session, [tf.saved_model.tag_constants.SERVING],
                                         signature_def_map=signature, legacy_init_op=legacy_init_op)
    builder.save()


def run():
    sess = K.get_session()
    keras_model_name = 'keras_model.h5'
    lite_model_name = 'lite_model_file.tflite'
    keras_model_file_path = os.path.join('./weights', keras_model_name)
    lite_model_file_path = os.path.join('./weights', lite_model_name)
    pb_model_path = os.path.join('./weights', 'saveBuilder')
    model = load_model(keras_model_file_path)
    output_names = [node.op.name for node in model.outputs]
    _ = freeze_session(sess, output_names=output_names)
    save_model(keras_model=model, session=sess, pb_model_path=pb_model_path)
    create_lite_model_from_saved_model(saved_model_dir=pb_model_path, tf_lite_path=lite_model_file_path)


if __name__ == "__main__":
    run()

Maybe it's helpful for someone.

The freeze_session(...) function I used is from: How to export Keras .h5 to tensorflow .pb?

If this is a model you wrote yourself, make sure weighted_metrics is defined when compiling the model:

model.compile(loss='binary_crossentropy', optimizer=<some_optimizer>, metrics=['accuracy'], weighted_metrics=['accuracy'])
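If recompiling and re-saving the model is not an option, one possible alternative is to add the missing key to the saved file itself. The sketch below is not from the original posts. It assumes the .h5 file was written by Keras model.save(), that the compile settings are stored as a JSON string in the file's 'training_config' attribute, and that h5py is installed. The helper names ensure_weighted_metrics and patch_keras_file are hypothetical. Try it on a copy of the file first.

```python
import json


def ensure_weighted_metrics(training_config_json):
    """Return the training config JSON with a 'weighted_metrics' key present.

    An existing value is left untouched; a missing key is filled with None,
    which matches compile() being called without weighted_metrics.
    """
    config = json.loads(training_config_json)
    config.setdefault('weighted_metrics', None)
    return json.dumps(config)


def patch_keras_file(h5_path):
    """Hypothetical helper: rewrite the 'training_config' attribute in place.

    Assumption: the attribute layout is the one Keras uses when saving a
    compiled model to HDF5. Files saved without compile() have no such
    attribute and are left unchanged.
    """
    import h5py  # imported lazily so the function above needs only the stdlib

    with h5py.File(h5_path, 'r+') as f:
        raw = f.attrs.get('training_config')
        if raw is None:
            return
        if isinstance(raw, bytes):
            raw = raw.decode('utf-8')
        # Store bytes again so loaders that call .decode() keep working.
        f.attrs['training_config'] = ensure_weighted_metrics(raw).encode('utf-8')
```

After calling patch_keras_file('keras_model.h5') on a copy, load_model should find the key it looks up in the traceback above. Whether the rest of the conversion then succeeds depends on the TensorFlow version in use.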
