
Tensorflow: preload multiple models

General question: how can you prevent a model from having to be rebuilt for each inference request?

I'm trying to develop a web service that contains multiple trained models which can be used to request a prediction. Producing a result is currently very time consuming because the model needs to be rebuilt for each request. The inference itself only takes 30 ms, but importing the model takes more than a second.
I'm having difficulty splitting the importing and the inference into two separate methods because both need the session.

The solution I came up with is to use an InteractiveSession that is stored in a variable. On creation of the object, the model gets loaded inside this session, which remains open. When a request is submitted, the preloaded model is then used to generate the result.

Problem with this solution:
When creating multiple of these objects for different models, multiple interactive sessions are open at the same time. TensorFlow generates the following warning:

Nesting violated for default stack of <class 'tensorflow.python.framework.ops.Graph'> objects

Any ideas on how to manage multiple sessions and preload models?

import importlib

import numpy as np
import tensorflow as tf

# Note: model_helper, utils (providing get_config_proto), nmt_utils,
# tokenize_input_string and untokenize_output_string come from the
# project's own g2p_tensor.nmt code and are assumed to be importable.


class model_inference:
    def __init__(self, language_name, base_module="models"):
        """
        Load a network that can be used to perform inference.

        Args:

            language_name (str): The name of an importable language class,
                returning an instance of `BaseLanguageModel`. This class
                should be importable from `base_module`.

            base_module (str):  The module from which to import the
                `language_name` class.

        Attributes:

            ckpt (str): The model checkpoint path.
            infer_model (g2p_tensor.nmt.model_helper.InferModel):
                The language infer_model instance.
        """

        language_instance = getattr(
            importlib.import_module(base_module), language_name
        )()
        self.ckpt = language_instance.checkpoint
        self.infer_model = language_instance.infer_model
        self.hparams = language_instance.hparams
        self.rebuild_infer_model()

    def rebuild_infer_model(self):
        """
        Recreate the infer model after changing hparams.
        This is time consuming.
        """
        self.session = tf.InteractiveSession(
            graph=self.infer_model.graph, config=utils.get_config_proto()
        )
        self.model = model_helper.load_model(
            self.infer_model.model, self.ckpt, self.session, "infer"
        )

    def infer_once(self, in_string):
        """
        Entrypoint of service, should not contain rebuilding of the model.
        """
        in_data = tokenize_input_string(in_string)

        self.session.run(
            self.infer_model.iterator.initializer,
            feed_dict={
                self.infer_model.src_placeholder: [in_data],
                self.infer_model.batch_size_placeholder: self.hparams.infer_batch_size,
            },
        )

        subword_option = self.hparams.subword_option
        beam_width = self.hparams.beam_width
        tgt_eos = self.hparams.eos
        num_translations_per_input = self.hparams.num_translations_per_input

        num_sentences = 0

        num_translations_per_input = max(
            min(num_translations_per_input, beam_width), 1
        )

        nmt_outputs, _ = self.model.decode(self.session)
        if beam_width == 0:
            nmt_outputs = np.expand_dims(nmt_outputs, 0)

        batch_size = nmt_outputs.shape[1]
        num_sentences += batch_size

        for sent_id in range(batch_size):
            for beam_id in range(num_translations_per_input):
                translation = nmt_utils.get_translation(
                    nmt_outputs[beam_id],
                    sent_id,
                    tgt_eos=tgt_eos,
                    subword_option=subword_option,
                )
        return untokenize_output_string(translation.decode("utf-8"))

    def __del__(self):
        self.session.close()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.session.close()
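
For illustration, a minimal usage sketch (the language class names "English" and "Dutch" are made-up placeholders) of how the warning surfaces: each constructor leaves an InteractiveSession open, so creating a second instance nests a new interactive session inside the first.

# Hypothetical usage: both constructors leave an InteractiveSession open,
# so the second call nests a new interactive session inside the first.
english = model_inference("English")
dutch = model_inference("Dutch")
# -> Nesting violated for default stack of <class '...ops.Graph'> objects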

With the help of jdehesa's comments I understood what went wrong.
When you don't specify which graph should be used, TensorFlow creates a new graph instance and adds the operations to it. That's why simply changing the InteractiveSession to a normal Session, to avoid nesting interactive sessions, throws a new error: ValueError: Operation name: "init_all_tables" op: "NoOp" is not an element of this graph.
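
A minimal sketch of the underlying mechanism (illustrative, not the original code): an op created without an explicit graph context lands in the global default graph, which is not the graph the session was built for.

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.constant(1, name="x")

sess = tf.Session(graph=g)
sess.run(x)  # fine: x lives in g

# This op is created in the *global* default graph, not in g:
y = tf.constant(2, name="y")
try:
    sess.run(y)
except ValueError as err:
    print(err)  # "... is not an element of this graph."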

Using an InteractiveSession worked because it sets the given graph as the default instead of creating a new instance. The problem with InteractiveSession is that it is very bad to leave multiple sessions open at the same time; TensorFlow will throw a warning.
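
The difference can be demonstrated with a small sketch (again illustrative, not the original code):

import tensorflow as tf

g = tf.Graph()

# InteractiveSession installs itself and its graph as the defaults:
sess = tf.InteractiveSession(graph=g)
print(tf.get_default_graph() is g)  # True: new ops land in g
sess.close()

# A plain Session leaves the default graph untouched:
sess = tf.Session(graph=g)
print(tf.get_default_graph() is g)  # False: new ops land elsewhere
with g.as_default():                # the graph context must be explicit
    print(tf.get_default_graph() is g)  # True
sess.close()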

The solution was the following: when changing the InteractiveSession to a normal Session, you need to explicitly define in which graph you want to reload the model with model_helper.load_model. This can be done by entering a context: with self.infer_model.graph.as_default():

The eventual solution was the following:

def rebuild_infer_model(self):
    """
    Recreate the infer model after changing hparams.
    This is time consuming.
    """
    self.session = tf.Session(
        graph=self.infer_model.graph, config=utils.get_config_proto()
    )
    # added line: the model must be loaded within the same graph that is
    # used when inferring!
    with self.infer_model.graph.as_default():
        model_helper.load_model(
            self.infer_model.model, self.ckpt, self.session, "infer"
        )

def infer_once(self, in_string):
    """
    Turn an orthographic transcription into a phonetic transcription
    The transcription is processed all at once
    Long transcriptions may result in incomplete phonetic output
    :param in_string: orthographic transcription
    :return: string of the phonetic representation
    """
    # added line:
    with self.infer_model.graph.as_default():
        in_data = tokenize_input_string(in_string)

        self.session.run(
            self.infer_model.iterator.initializer,
            feed_dict={
                self.infer_model.src_placeholder: [in_data],
                self.infer_model.batch_size_placeholder: self.hparams.infer_batch_size,
            },
        )

        subword_option = self.hparams.subword_option
        beam_width = self.hparams.beam_width
        tgt_eos = self.hparams.eos
        num_translations_per_input = self.hparams.num_translations_per_input

        num_sentences = 0

        num_translations_per_input = max(
            min(num_translations_per_input, beam_width), 1
        )

        nmt_outputs, _ = self.infer_model.model.decode(self.session)
        if beam_width == 0:
            nmt_outputs = np.expand_dims(nmt_outputs, 0)

        batch_size = nmt_outputs.shape[1]
        num_sentences += batch_size

        for sent_id in range(batch_size):
            for beam_id in range(num_translations_per_input):
                translation = nmt_utils.get_translation(
                    nmt_outputs[beam_id],
                    sent_id,
                    tgt_eos=tgt_eos,
                    subword_option=subword_option,
                )
    return untokenize_output_string(translation.decode("utf-8"))
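
With this in place, several models can be preloaded once and reused per request. A hedged usage sketch (the language class names are again placeholders):

# Preload every model once at startup; each object keeps its own graph
# and its own tf.Session open.
models = {
    name: model_inference(name)
    for name in ("English", "Dutch")  # placeholder language class names
}

def predict(language_name, in_string):
    # Reuses the session opened in rebuild_infer_model(), so each request
    # only pays the ~30 ms inference cost, not the >1 s model import.
    return models[language_name].infer_once(in_string)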
