繁体   English   中英

如何用 tensorflow 2 中的适当代码替换 tensorflow 1.x 中的 tf.graph 和 tf.session?

[英]How to replace tf.graph and tf.session from tensorflow 1.x with appropriate code in tensorflow 2?

我正在将代码从 tensorflow 1.x 迁移到 tensorflow 2.7。 我重新训练了模型并用新模型替换了旧模型,但是,与旧模型相比,它们非常慢,我认为原因是现在为发送到服务器的每个输入加载模型,而在 tf 1.x 中,使用了会话和图表,因此模型仅加载一次,因此减少了计算量。 这是用于在 tf 1 中加载模型的代码。

set_session(sess)
englishToHindiEncoder=load_model("models/englishToHindiEncoder.h5",compile=False)
englishToHindiEncoder._make_predict_function()
englishToHindiDecoder=load_model("models/englishToHindiDecoder.h5",compile=False)
englishToHindiDecoder._make_predict_function()
graph = tf.compat.v1.get_default_graph()

此外,下面是调用和使用这些模型的代码。

def decode_sequence(vector):
  revhindiDict = dict((i, char) for char, i in hindiDict.items())
  global sess
  global graph
  states_value = np.zeros((1, 1, 82))
  with graph.as_default():
    set_session(sess)
    states_value = englishToHindiEncoder.predict(vector)
  target_seq = np.zeros((1, 1, 82))
  target_seq[0, 0, hindiDict['\t']] = 1

我想知道如何删除这些代码行并用相应的 tensorflow 2.0 代码替换它们。 需要删除的行如下:

set_session(sess)
graph = tf.compat.v1.get_default_graph()

with graph.as_default():
   set_session(sess)

因为这些行现在在 tf 2 中已过时。一种方法是直接删除这些行,但它会显着增加批量发送请求时的处理时间。 所以,我想知道在 tensorflow 2.0 中替换那些会话和图形代码并重写它的最优化和最正确的方法。

在 Tensorflow 2.X 中,要在图形模式下运行代码,您可以使用@tf.function包装函数。

例如:

@tf.function
def decode_sequence(vector):
  revhindiDict = dict((i, char) for char, i in hindiDict.items())
  #global sess
  #global graph
  states_value = np.zeros((1, 1, 82))
  #with graph.as_default():
    #set_session(sess)
  states_value = englishToHindiEncoder.predict(vector)
  target_seq = np.zeros((1, 1, 82))
  target_seq[0, 0, hindiDict['\t']] = 1

有关详细信息,请参阅此文档

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM