
Retrain Frozen Graph in Tensorflow 2.x

I have managed to implement retraining of a frozen graph in TensorFlow 1 following this wonderful, detailed topic. Basically, the methodology is:

  1. Load the frozen model.
  2. Replace each frozen constant node with a variable node.
  3. Redirect the consumers of each frozen node's output to the corresponding new variable node.

This works in TensorFlow 1.x, as verified by checking tf.compat.v1.trainable_variables. However, in TensorFlow 2.x it no longer works.

Below are the code snippets:

1/ Load frozen model

frozen_path = '...'
detection_graph = tf.Graph()
with detection_graph.as_default():
    od_graph_def = tf.compat.v1.GraphDef()
    with tf.compat.v1.io.gfile.GFile(frozen_path, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.graph_util.import_graph_def(od_graph_def, name='')

2/ Create a clone

with detection_graph.as_default():
    const_var_name_pairs = {}
    probable_variables = [op for op in detection_graph.get_operations() if op.type == "Const"]
    available_names = [op.name for op in detection_graph.get_operations()]
    # One session for all constant reads, instead of opening a new one per constant
    with tf.compat.v1.Session() as s:
        for op in probable_variables:
            name = op.name
            # Only constants that feed a '/read' identity op are weights
            if name + '/read' not in available_names:
                continue
            tensor = detection_graph.get_tensor_by_name('{}:0'.format(name))
            tensor_as_numpy_array = s.run(tensor)
            var_shape = tensor.get_shape()
            # Give each variable a name that doesn't already exist in the graph
            var_name = '{}_turned_var'.format(name)
            var = tf.Variable(name=var_name, dtype=op.outputs[0].dtype,
                              initial_value=tensor_as_numpy_array,
                              trainable=True, shape=var_shape)
            const_var_name_pairs[name] = var_name

3/ Replace the frozen nodes with the Graph Editor

import graph_def_editor as ge
ge_graph = ge.Graph(detection_graph.as_graph_def())
name_to_op = dict([(n.name, n) for n in ge_graph.nodes])
for const_name, var_name in const_var_name_pairs.items():
    const_op = name_to_op[const_name+'/read']
    var_reader_op = name_to_op[var_name + '/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
detection_training_graph = ge_graph.to_tf_graph()
with detection_training_graph.as_default():
    writer = tf.compat.v1.summary.FileWriter('remap', detection_training_graph)
    writer.close()

The problem was that I imported the tf.GraphDef into the Graph Editor instead of the original tf.Graph, so the Variables were lost.

The quick fix is to rework step 3.

Sol1: Using the Graph Editor

ge_graph = ge.Graph(detection_graph)
for const_name, var_name in const_var_name_pairs.items():
    const_op = ge_graph._node_name_to_node[const_name+'/read']
    var_reader_op = ge_graph._node_name_to_node[var_name+'/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))

However, this requires disabling eager execution. To keep eager execution enabled, attach the MetaGraphDef to the Graph Editor as below:

with detection_graph.as_default():
    meta_saver = tf.compat.v1.train.Saver()
    meta = meta_saver.export_meta_graph()
ge_graph = ge.Graph(detection_graph, collections=ge.graph._extract_collection_defs(meta))

However, the trickiest part is making the model trainable in TF 2.x. Instead of using the Graph Editor to export the graph directly, we should export it ourselves. The reason is that the Graph Editor turns the Variables' data type into resources. Therefore, we should export the graph as a GraphDef and import each VariableDef into the new graph:

from tensorflow.core.framework import variable_pb2  # needed for VariableDef

test_graph = tf.Graph()
with test_graph.as_default():
    tf.import_graph_def(ge_graph.to_graph_def(), name="")
    for var_name in ge_graph.variable_names:
        var = ge_graph.get_variable_by_name(var_name)
        ret = variable_pb2.VariableDef()
        ret.variable_name = var._variable_name
        ret.initial_value_name = var._initial_value_name
        ret.initializer_name = var._initializer_name
        ret.snapshot_name = var._snapshot_name
        ret.trainable = var._trainable
        ret.is_resource = True
        tf_var = tf.Variable(variable_def=ret, dtype=tf.float32)
        test_graph.add_to_collections(var.collection_names, tf_var)

Sol2: Manually remap the GraphDef

with detection_graph.as_default() as graph:
    training_graph_def = remap_input_node(detection_graph.as_graph_def(), const_var_name_pairs)
    current_var = tf.compat.v1.trainable_variables()
    assert len(current_var) > 0, "no training variables"
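
The post does not show remap_input_node. A minimal sketch of what it presumably does — rewiring each node input that points at a frozen constant's '/read' identity op so it points at the replacement variable's read op instead — could look like this (the helper name and the exact matching rules are assumptions, not part of the original post):

```python
def remap_input_node(graph_def, const_var_name_pairs):
    """Rewire inputs in a GraphDef from frozen constants to variables.

    const_var_name_pairs maps each constant's name to the name of the
    tf.Variable created to replace it. Control inputs (prefixed '^')
    are left untouched in this sketch.
    """
    for node in graph_def.node:
        for i, inp in enumerate(node.input):
            if inp.startswith('^'):
                continue  # control dependency, not a data input
            # Strip an optional ':0' output index for matching
            base = inp.split(':')[0]
            if base.endswith('/read'):
                const_name = base[:-len('/read')]
                if const_name in const_var_name_pairs:
                    var_name = const_var_name_pairs[const_name]
                    # Point the consumer at the variable's read op instead
                    node.input[i] = var_name + '/Read/ReadVariableOp'
    return graph_def
```

With real protobufs the item assignment on node.input works the same way, since repeated scalar fields support indexed assignment.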


from tensorflow.core.framework import variable_pb2  # needed for VariableDef

detection_training_graph = tf.Graph()
with detection_training_graph.as_default():
    tf.graph_util.import_graph_def(training_graph_def, name='')
    for var in current_var:
        ret = variable_pb2.VariableDef()
        ret.variable_name = var.name
        ret.initial_value_name = var.name[:-2] + '/Initializer/initial_value:0'
        ret.initializer_name = var.name[:-2] + '/Assign'
        ret.snapshot_name = var.name[:-2] + '/Read/ReadVariableOp:0'
        ret.trainable = True
        ret.is_resource = True
        tf_var = tf.Variable(variable_def=ret, dtype=tf.float32)
        detection_training_graph.add_to_collections({'trainable_variables', 'variables'}, tf_var)
    current_var = tf.compat.v1.trainable_variables()
    assert len(current_var) > 0, "no training variables"
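
The VariableDef endpoint names above follow a fixed naming convention: every endpoint is derived from the variable's tensor name by stripping the ':0' output suffix. As a plain-Python illustration of that mapping (the helper function is mine, not from the original code):

```python
def variable_def_endpoints(var_tensor_name):
    """Derive the graph endpoint names used to rebuild a VariableDef
    from a TF1-style variable tensor name like 'w_turned_var:0'.
    The suffixes mirror the ops TF1 creates alongside a variable."""
    base = var_tensor_name[:-2]  # strip the ':0' output index
    return {
        'variable_name': var_tensor_name,
        'initial_value_name': base + '/Initializer/initial_value:0',
        'initializer_name': base + '/Assign',
        'snapshot_name': base + '/Read/ReadVariableOp:0',
    }
```

This is only a readability aid; in the snippet above the same strings are built inline with var.name[:-2].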
