I managed to implement retraining of a frozen graph in TensorFlow 1 by following this wonderfully detailed topic. Basically, the methodology is to replace each constant frozen node with a variable node. This works in TensorFlow 1.x, as can be confirmed by checking tf.compat.v1.trainable_variables(). However, in TensorFlow 2.x it no longer works.
Below is the code snippet:
1/ Load frozen model
# Read the serialized frozen GraphDef from disk and import it into a fresh
# tf.Graph so that the following steps can inspect and rewrite its nodes.
frozen_path = '...'
detection_graph = tf.Graph()
with detection_graph.as_default():
    with tf.compat.v1.io.gfile.GFile(frozen_path, 'rb') as frozen_file:
        od_graph_def = tf.compat.v1.GraphDef()
        od_graph_def.ParseFromString(frozen_file.read())
    # An empty name prefix keeps the node names exactly as stored in the file.
    tf.graph_util.import_graph_def(od_graph_def, name='')
2/ Create a clone
# Create one trainable tf.Variable for every frozen weight in the graph.
# A Const op is treated as a frozen weight when the graph also contains a
# matching '<name>/read' op. The resulting mapping
# {const op name -> variable name} is stored in const_var_name_pairs.
with detection_graph.as_default():
    const_var_name_pairs = {}
    probable_variables = [op for op in detection_graph.get_operations() if op.type == "Const"]
    # A set makes every '<name>/read' membership test O(1) instead of a list scan.
    available_names = {op.name for op in detection_graph.get_operations()}
    frozen_weight_ops = [
        op for op in probable_variables
        if op.name + '/read' in available_names
    ]
    frozen_tensors = [
        detection_graph.get_tensor_by_name('{}:0'.format(op.name))
        for op in frozen_weight_ops
    ]
    # Evaluate all constants with a single session and a single run() call
    # instead of opening a new session for each constant.
    with tf.compat.v1.Session() as s:
        frozen_values = s.run(frozen_tensors)
    for op, tensor, tensor_as_numpy_array in zip(frozen_weight_ops, frozen_tensors, frozen_values):
        name = op.name
        # Give each variable a name that doesn't already exist in the graph
        var_name = '{}_turned_var'.format(name)
        var = tf.Variable(
            name=var_name,
            dtype=op.outputs[0].dtype,
            initial_value=tensor_as_numpy_array,
            trainable=True,
            shape=tensor.get_shape(),
        )
        const_var_name_pairs[name] = var_name
3/ Replace frozen node by Graph Editor
import graph_def_editor as ge

# Redirect the output of every frozen '<const>/read' op to the read op of its
# replacement variable, then dump the rewritten graph for inspection.
# NOTE: the Graph Editor graph here is built from a GraphDef, not from the
# tf.Graph object itself.
ge_graph = ge.Graph(detection_graph.as_graph_def())
name_to_op = {n.name: n for n in ge_graph.nodes}
for const_name, var_name in const_var_name_pairs.items():
    const_op = name_to_op[const_name + '/read']
    var_reader_op = name_to_op[var_name + '/Read/ReadVariableOp']
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
detection_training_graph = ge_graph.to_tf_graph()
with detection_training_graph.as_default():
    writer = tf.compat.v1.summary.FileWriter('remap', detection_training_graph)
    # close() must actually be called; a bare `writer.close` only references
    # the method and never closes the writer.
    writer.close()
The problem was that I built the Graph Editor graph from the tf.GraphDef instead of from the original tf.Graph, which has the Variables.
This can be solved quickly by fixing step 3.
Sol1: Using Graph Editor
# Build the Graph Editor graph from the tf.Graph object itself, then point the
# consumers of each frozen '<const>/read' op at the matching variable read op.
ge_graph = ge.Graph(detection_graph)
for frozen_name, variable_name in const_var_name_pairs.items():
    frozen_read_key = frozen_name + '/read'
    variable_read_key = variable_name + '/Read/ReadVariableOp'
    const_op = ge_graph._node_name_to_node[frozen_read_key]
    var_reader_op = ge_graph._node_name_to_node[variable_read_key]
    ge.swap_outputs(ge.sgv(const_op), ge.sgv(var_reader_op))
However, this requires disabling eager execution. To make it work with eager execution enabled, you should attach the MetaGraphDef to the Graph Editor as below:
# Export a MetaGraphDef for the graph and hand its collection definitions to
# the Graph Editor when the editor graph is constructed.
with detection_graph.as_default():
    meta = tf.compat.v1.train.Saver().export_meta_graph()
collection_defs = ge.graph._extract_collection_defs(meta)
ge_graph = ge.Graph(detection_graph, collections=collection_defs)
However, this is the trickiest part of making the model trainable in tf2.x. Instead of using the Graph Editor to export the graph directly, we should export it ourselves. The reason is that the Graph Editor makes the Variables' data type resource. Therefore, we should export the graph as a GraphDef and import the variable defs into the graph:
# Import the edited GraphDef into a new tf.Graph, then recreate each variable
# from a VariableDef and register it in the collections recorded by the
# Graph Editor.
test_graph = tf.Graph()
with test_graph.as_default():
    tf.import_graph_def(ge_graph.to_graph_def(), name="")
    for var_name in ge_graph.variable_names:
        var = ge_graph.get_variable_by_name(var_name)
        # Copy the editor's bookkeeping for this variable into a VariableDef
        # proto, flagged as a resource variable.
        ret = variable_pb2.VariableDef(
            variable_name=var._variable_name,
            initial_value_name=var._initial_value_name,
            initializer_name=var._initializer_name,
            snapshot_name=var._snapshot_name,
            trainable=var._trainable,
            is_resource=True,
        )
        tf_var = tf.Variable(variable_def=ret, dtype=tf.float32)
        test_graph.add_to_collections(var.collection_names, tf_var)
Sol2: Manually map by Graphdef
# Remap the frozen nodes directly on the GraphDef, then rebuild the trainable
# variables inside a new graph from hand-written VariableDef protos.
with detection_graph.as_default():
    training_graph_def = remap_input_node(detection_graph.as_graph_def(), const_var_name_pairs)
    current_var = tf.compat.v1.trainable_variables()
    # Raise explicitly so the check is not stripped when Python runs with -O.
    if not current_var:
        raise AssertionError("no training variables")

detection_training_graph = tf.Graph()
with detection_training_graph.as_default():
    tf.graph_util.import_graph_def(training_graph_def, name='')
    for var in current_var:
        # Strip the ':<output index>' suffix once to get the op-name prefix
        # shared by the variable's initializer, assign and read ops.
        base_name = var.name.rsplit(':', 1)[0]
        ret = variable_pb2.VariableDef()
        ret.variable_name = var.name
        ret.initial_value_name = base_name + '/Initializer/initial_value:0'
        ret.initializer_name = base_name + '/Assign'
        ret.snapshot_name = base_name + '/Read/ReadVariableOp:0'
        ret.trainable = True
        ret.is_resource = True
        tf_var = tf.Variable(variable_def=ret, dtype=tf.float32)
        detection_training_graph.add_to_collections({'trainable_variables', 'variables'}, tf_var)
    # Verify that the rebuilt graph exposes trainable variables.
    current_var = tf.compat.v1.trainable_variables()
    if not current_var:
        raise AssertionError("no training variables")
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.