
Wrapping a frozen TensorFlow pb in a tf.keras Model

I am trying to use a frozen, pretrained DeepLabv3 model in a larger tf.keras training pipeline, but I can't figure out how to use it as a tf.keras Model. I want to use tf.keras because I suspect that running a feed_dict (the only way I know of to use a frozen graph) in the middle of multiple forward passes would slow things down. The deeplab model referenced in the code below is built in regular Keras (as opposed to tf.contrib.keras).

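import tensorflow as tf
from keras import backend as K
from tensorflow.keras import models

# Create, compile and train model...
# (`deeplab` is the trained Keras model; freeze_session() and load_graph()
#  are helper functions that are not shown here)

frozen_graph = freeze_session(K.get_session(),
                              output_names=[out.op.name for out in deeplab.outputs])
tf.train.write_graph(frozen_graph, "./", "my_model.pb", as_text=False)
graph = load_graph("my_model.pb")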

# We can verify that we can access the list of operations in the graph
for op in graph.get_operations():
    print(op.name)
    # prefix/Placeholder/inputs_placeholder
    # ...
    # prefix/Accuracy/predictions

# We access the input and output nodes 
x = graph.get_tensor_by_name("prefix/input_1:0")
y = graph.get_tensor_by_name("prefix/bilinear_upsampling_2/ResizeBilinear:0")

# We launch a Session
with tf.Session(graph=graph) as sess:
    print(graph)
    model2 = models.Model(inputs=x,outputs=y)
    model2.summary()

and I get the following error:

ValueError: Input tensors to a Model must come from `tf.layers.Input`. Received: Tensor("prefix/input_1:0", shape=(?, 512, 512, 3), dtype=float32) (missing previous layer metadata).

I think I've seen others replace the input tensor with an Input layer to trick tf.keras into building the graph, but after a few hours I am still stuck. Any help would be appreciated!
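The error occurs because the tensor returned by get_tensor_by_name lives in a separate tf.Graph and carries no Keras layer metadata, while tf.keras.Model only accepts tensors produced by Input and Keras layers. A minimal sketch of the "replace the input with an Input layer" idea, assuming TensorFlow 2.x: import the frozen GraphDef inside tf.compat.v1.wrap_function, prune it down to one input and one output, and call the resulting function from a custom layer that sits behind a real tf.keras.Input. The toy graph (y = 2 * x + 1) stands in for DeepLab, and the names write_toy_frozen_graph, load_graph_def, wrap_frozen_graph and FrozenGraphLayer are hypothetical. For the real my_model.pb, the tensor names would presumably be "prefix/input_1:0" and "prefix/bilinear_upsampling_2/ResizeBilinear:0", since the sketch imports under the same "prefix" scope that load_graph appears to use.

```python
import tempfile

import numpy as np
import tensorflow as tf


def write_toy_frozen_graph(path_dir):
    """Write a hypothetical stand-in for my_model.pb computing y = 2 * x + 1.

    The graph has no variables, so it is already "frozen" (everything is constant).
    """
    graph = tf.Graph()
    with graph.as_default():
        x = tf.compat.v1.placeholder(tf.float32, shape=(None, 3), name="input_1")
        tf.add(tf.multiply(x, 2.0), 1.0, name="output")
    # write_graph returns the path of the written .pb file.
    return tf.io.write_graph(graph.as_graph_def(), path_dir, "my_model.pb", as_text=False)


def load_graph_def(pb_path):
    """Parse a binary GraphDef from disk."""
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def


def wrap_frozen_graph(graph_def, inputs, outputs):
    """Import the GraphDef under the "prefix" scope and prune it to a callable."""
    def _imports_graph_def():
        tf.compat.v1.import_graph_def(graph_def, name="prefix")

    wrapped_import = tf.compat.v1.wrap_function(_imports_graph_def, [])
    import_graph = wrapped_import.graph
    return wrapped_import.prune(
        tf.nest.map_structure(import_graph.as_graph_element, inputs),
        tf.nest.map_structure(import_graph.as_graph_element, outputs),
    )


class FrozenGraphLayer(tf.keras.layers.Layer):
    """Hypothetical layer that forwards its input through the frozen function."""

    def __init__(self, frozen_func, **kwargs):
        super().__init__(**kwargs)
        self.frozen_func = frozen_func

    def call(self, inputs):
        return self.frozen_func(inputs)


pb_path = write_toy_frozen_graph(tempfile.mkdtemp())
frozen_func = wrap_frozen_graph(
    load_graph_def(pb_path), inputs="prefix/input_1:0", outputs="prefix/output:0"
)

# The graph input is fed from a real tf.keras Input, so Model() accepts it.
inp = tf.keras.Input(shape=(3,))
out = FrozenGraphLayer(frozen_func)(inp)
model = tf.keras.Model(inputs=inp, outputs=out)

result = model(tf.constant([[1.0, 2.0, 3.0]]))
print(result.numpy())  # [[3. 5. 7.]]
```

Because a frozen graph stores its weights as constants, this layer has no trainable variables; it behaves as a fixed feature extractor inside the larger model, and no feed_dict is involved.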

You can recreate the model object from its config. See the from_config method here: https://keras.io/models/about-keras-models/
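A minimal sketch of the config round trip, assuming TensorFlow 2.x and a small hypothetical functional model in place of the trained DeepLab model. The config describes only the architecture, so the weights have to be copied over separately with get_weights/set_weights.

```python
import numpy as np
import tensorflow as tf

# Hypothetical small model standing in for the trained Keras model.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(2, activation="relu")(inputs)
original = tf.keras.Model(inputs=inputs, outputs=outputs)

# The config describes layers and their connections, not their weights.
config = original.get_config()
rebuilt = tf.keras.Model.from_config(config)

# Copy the weights explicitly, since they are not part of the config.
rebuilt.set_weights(original.get_weights())

x = np.random.rand(3, 4).astype("float32")
same = np.allclose(original(x).numpy(), rebuilt(x).numpy())
print(same)  # True
```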

The config is stored and loaded back by the save_model/load_model functions. I am not familiar with freeze_session.
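A short sketch of the save_model/load_model round trip, assuming TensorFlow 2.x; the toy model and the file name model.keras are illustrative only. The saved file holds both the config and the weights, so the restored model needs no separate set_weights call.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical small model standing in for the trained Keras model.
inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(2)(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Saves the architecture (config) together with the weights.
path = os.path.join(tempfile.mkdtemp(), "model.keras")
tf.keras.models.save_model(model, path)
restored = tf.keras.models.load_model(path)

x = np.random.rand(3, 4).astype("float32")
matches = np.allclose(model(x).numpy(), restored(x).numpy())
print(matches)  # True
```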
