简体   繁体   English

TensorFlow2.0 中的 XLA - 冻结模型?

[英]XLA in TensorFlow2.0 - frozen model?

I was following the offcial guide for XLA AOT compilation ( https://www.tensorflow.org/xla/tfcompile ), and compiling the examples works just fine (inside aot/tests).我正在遵循 XLA AOT 编译的官方指南 ( https://www.tensorflow.org/xla/tfcompile ),并且编译示例工作得很好(在 aot/tests 中)。

But then I wanted to compile some slightly bigger models, and a problem arises: if XLA AOT requires a frozen graph as input (as I understand from the guide) and frozen graphs are not supported anymore in TensorFlow 2, what input does XLA expect now?但是后来我想编译一些稍大的模型,并且出现了一个问题:如果 XLA AOT 需要一个冻结图作为输入(我从指南中了解到)并且 TensorFlow 2 不再支持冻结图,那么 XLA 现在期望什么输入?

It seems like there are still ways to freeze a graph in TensorFlow 2. I followed this post to create a frozen graph and it worked to compile it afterward: https://leimao.github.io/blog/Save-Load-Inference-From-TF2-Frozen-Graph/似乎仍然有办法在 TensorFlow 2 中冻结图。我按照这篇文章创建了一个冻结图,然后它可以编译它: https : //leimao.github.io/blog/Save-Load-Inference- From-TF2-Frozen-Graph/

# Convert Keras model to ConcreteFunction
full_model = tf.function(lambda x: model(x))
full_model = full_model.get_concrete_function(
    tf.TensorSpec(model.inputs[0].shape, model.inputs[0].dtype))

# Get frozen ConcreteFunction
frozen_func = convert_variables_to_constants_v2(full_model)
frozen_func.graph.as_graph_def()

layers = [op.name for op in frozen_func.graph.get_operations()]
print("-" * 50)
print("Frozen model layers: ")
for layer in layers:
    print(layer)

print("-" * 50)
print("Frozen model inputs: ")
print(frozen_func.inputs)
print("Frozen model outputs: ")
print(frozen_func.outputs)

# Save frozen graph from frozen ConcreteFunction to hard drive
tf.io.write_graph(graph_or_graph_def=frozen_func.graph,
                  logdir="./frozen_models",
                  name="frozen_graph.pb",
                  as_text=False)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM