Tensorflow: preload multiple models
General question: how do you avoid having to rebuild the model for every inference request?
I am trying to develop a web service that contains multiple trained models which can be used to request a prediction. Generating results is currently very time-consuming because the model has to be rebuilt for every request. The inference itself only takes about 30 ms, but importing the model takes more than a second.
I am having difficulty splitting the import and the inference into two separate methods, because a session is needed.
The solution I came up with is to use an InteractiveSession that is stored in a variable. When the object is created, the model is loaded into this session, which is kept open. When a request is submitted, this preloaded model is used to generate the result.
The problem with this solution: when multiple of these objects are created for different models, several interactive sessions are open at the same time. TensorFlow then emits the following warning:
Nesting violated for default stack of <class 'tensorflow.python.framework.ops.Graph'> objects
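For illustration, a minimal snippet that can trigger this kind of warning (TF 1.x; an InteractiveSession pushes its graph onto the default stack at creation, so cleaning the sessions up out of LIFO order violates that stack — the exact trigger point may vary by TF version):

import tensorflow as tf

g1, g2 = tf.Graph(), tf.Graph()
sess1 = tf.InteractiveSession(graph=g1)  # pushes g1 onto the default stacks
sess2 = tf.InteractiveSession(graph=g2)  # pushes g2 on top of g1

# Closing the sessions in creation order pops g1 while g2 is still on top,
# which can produce "Nesting violated for default stack ..." warnings.
sess1.close()
sess2.close()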
Any ideas on how to manage multiple sessions and preloaded models?
import importlib

import numpy as np
import tensorflow as tf

# Project-specific helpers (model_helper, nmt_utils, utils,
# tokenize_input_string, untokenize_output_string) are assumed to be
# imported from the surrounding project.


class model_inference:
    def __init__(self, language_name, base_module="models"):
        """
        Load a network that can be used to perform inference.

        Args:
            language_name (str): The name of an importable language class,
                returning an instance of `BaseLanguageModel`. This class
                should be importable from `base_module`.
            base_module (str): The module from which to import the
                `language_name` class.

        Attributes:
            ckpt (str): The model checkpoint value.
            infer_model (g2p_tensor.nmt.model_helper.InferModel):
                The language infer_model instance.
        """
        language_instance = getattr(
            importlib.import_module(base_module), language_name
        )()
        self.ckpt = language_instance.checkpoint
        self.infer_model = language_instance.infer_model
        self.hparams = language_instance.hparams
        self.rebuild_infer_model()

    def rebuild_infer_model(self):
        """
        Recreate the infer model after changing hparams.
        This is time-consuming.
        :return:
        """
        self.session = tf.InteractiveSession(
            graph=self.infer_model.graph, config=utils.get_config_proto()
        )
        self.model = model_helper.load_model(
            self.infer_model.model, self.ckpt, self.session, "infer"
        )

    def infer_once(self, in_string):
        """
        Entry point of the service; should not contain rebuilding of the model.
        """
        in_data = tokenize_input_string(in_string)
        self.session.run(
            self.infer_model.iterator.initializer,
            feed_dict={
                self.infer_model.src_placeholder: [in_data],
                self.infer_model.batch_size_placeholder: self.hparams.infer_batch_size,
            },
        )
        subword_option = self.hparams.subword_option
        beam_width = self.hparams.beam_width
        tgt_eos = self.hparams.eos
        num_translations_per_input = self.hparams.num_translations_per_input
        num_sentences = 0
        num_translations_per_input = max(
            min(num_translations_per_input, beam_width), 1
        )
        nmt_outputs, _ = self.model.decode(self.session)
        if beam_width == 0:
            nmt_outputs = np.expand_dims(nmt_outputs, 0)
        batch_size = nmt_outputs.shape[1]
        num_sentences += batch_size
        for sent_id in range(batch_size):
            for beam_id in range(num_translations_per_input):
                translation = nmt_utils.get_translation(
                    nmt_outputs[beam_id],
                    sent_id,
                    tgt_eos=tgt_eos,
                    subword_option=subword_option,
                )
        return untokenize_output_string(translation.decode("utf-8"))

    def __del__(self):
        self.session.close()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.session.close()
With the help of jdehesa's comments I figured out what was going wrong.
When you do not specify which graph should be used, TensorFlow creates a new graph instance and adds the operations to it. That is why simply changing the InteractiveSession to a plain Session (to avoid nesting interactive sessions) raised the new error ValueError: Operation name: "init_all_tables" op: "NoOp" is not an element of this graph.
Using an InteractiveSession worked because it sets the given graph as the default instead of creating a new instance. The problem with InteractiveSession is that keeping multiple sessions open at the same time is very bad practice, and TensorFlow emits a warning about it.
The solution is the following: change the InteractiveSession to a plain Session. With a plain Session you need to explicitly define in which graph model_helper.load_model should reload the model. This can be done by entering a context: with self.infer_model.graph.as_default():
The final solution looks like this:
def rebuild_infer_model(self):
    """
    Recreate the infer model after changing hparams.
    This is time-consuming.
    :return:
    """
    self.session = tf.Session(
        graph=self.infer_model.graph, config=utils.get_config_proto()
    )
    # added line:
    with self.infer_model.graph.as_default():  # the model must be loaded into the same graph that is used for inference!
        model_helper.load_model(
            self.infer_model.model, self.ckpt, self.session, "infer"
        )

def infer_once(self, in_string):
    """
    Turn an orthographic transcription into a phonetic transcription.
    The transcription is processed all at once.
    Long transcriptions may result in incomplete phonetic output.
    :param in_string: orthographic transcription
    :return: string of the phonetic representation
    """
    # added line:
    with self.infer_model.graph.as_default():
        in_data = tokenize_input_string(in_string)
        self.session.run(
            self.infer_model.iterator.initializer,
            feed_dict={
                self.infer_model.src_placeholder: [in_data],
                self.infer_model.batch_size_placeholder: self.hparams.infer_batch_size,
            },
        )
        subword_option = self.hparams.subword_option
        beam_width = self.hparams.beam_width
        tgt_eos = self.hparams.eos
        num_translations_per_input = self.hparams.num_translations_per_input
        num_sentences = 0
        num_translations_per_input = max(
            min(num_translations_per_input, beam_width), 1
        )
        nmt_outputs, _ = self.infer_model.model.decode(self.session)
        if beam_width == 0:
            nmt_outputs = np.expand_dims(nmt_outputs, 0)
        batch_size = nmt_outputs.shape[1]
        num_sentences += batch_size
        for sent_id in range(batch_size):
            for beam_id in range(num_translations_per_input):
                translation = nmt_utils.get_translation(
                    nmt_outputs[beam_id],
                    sent_id,
                    tgt_eos=tgt_eos,
                    subword_option=subword_option,
                )
        return untokenize_output_string(translation.decode("utf-8"))
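With this fix each model owns its own graph and its own plain Session, so several models can be preloaded side by side without nesting warnings. As a sketch of how a service could keep them around (the dict-based registry and the language names are illustrative assumptions, not part of the original code):

# Preload one model_inference instance per language at startup,
# so each request only pays the ~30 ms inference cost.
MODELS = {
    name: model_inference(name)
    for name in ("dutch", "english")  # illustrative language class names
}

def handle_request(language, text):
    """Look up the preloaded model and run a single inference."""
    return MODELS[language].infer_once(text)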