[英]How to replace tf.graph and tf.session from tensorflow 1.x with appropriate code in tensorflow 2?
我正在将代码从 tensorflow 1.x 迁移到 tensorflow 2.7。 我重新训练了模型并用新模型替换了旧模型,但是,与旧模型相比,它们非常慢,我认为原因是现在为发送到服务器的每个输入加载模型,而在 tf 1.x 中,使用了会话和图表,因此模型仅加载一次,因此减少了计算量。 这是用于在 tf 1 中加载模型的代码。
set_session(sess)
englishToHindiEncoder=load_model("models/englishToHindiEncoder.h5",compile=False)
englishToHindiEncoder._make_predict_function()
englishToHindiDecoder=load_model("models/englishToHindiDecoder.h5",compile=False)
englishToHindiDecoder._make_predict_function()
graph = tf.compat.v1.get_default_graph()
此外,下面是调用和使用这些模型的代码。
def decode_sequence(vector):
revhindiDict = dict((i, char) for char, i in hindiDict.items())
global sess
global graph
states_value = np.zeros((1, 1, 82))
with graph.as_default():
set_session(sess)
states_value = englishToHindiEncoder.predict(vector)
target_seq = np.zeros((1, 1, 82))
target_seq[0, 0, hindiDict['\t']] = 1
我想知道如何删除这些代码行并用相应的 tensorflow 2.0 代码替换它们。 需要删除的行如下:
set_session(sess)
graph = tf.compat.v1.get_default_graph()
和
with graph.as_default():
set_session(sess)
因为这些行现在在 tf 2 中已过时。一种方法是直接删除这些行,但它会显着增加批量发送请求时的处理时间。 所以,我想知道在 tensorflow 2.0 中替换那些会话和图形代码并重写它的最优化和最正确的方法。
在 Tensorflow 2.X 中,要在图形模式下运行代码,您可以使用@tf.function
包装函数。
例如:
@tf.function
def decode_sequence(vector):
revhindiDict = dict((i, char) for char, i in hindiDict.items())
#global sess
#global graph
states_value = np.zeros((1, 1, 82))
#with graph.as_default():
#set_session(sess)
states_value = englishToHindiEncoder.predict(vector)
target_seq = np.zeros((1, 1, 82))
target_seq[0, 0, hindiDict['\t']] = 1
有关详细信息,请参阅此文档。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.