I have been unable to figure out how to use transfer learning / last-layer retraining with the new TF Estimator API.
The Estimator requires a model_fn, which contains the architecture of the network plus the training and eval ops, as defined in the documentation. An example of a model_fn using a CNN architecture is here.
If I want to retrain the last layer of, for example, the Inception architecture, I'm not sure whether I need to specify the whole model in this model_fn and then load the pre-trained weights, or whether there is a way to use the saved graph, as is done in the 'traditional' approach (example here).
This has been brought up as an issue, but it is still open and the answers there are unclear to me.
It is possible to load the metagraph during model definition and use a SessionRunHook to load the weights from a .ckpt file.
def model_fn(features, labels, mode, params):
    # Build the graph here (e.g. import the pre-trained metagraph),
    # producing predictions, loss, and train_op.
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        training_hooks=[RestoreHook()])
The SessionRunHook can be:
class RestoreHook(tf.train.SessionRunHook):
    def after_create_session(self, session, coord=None):
        # Only restore when training starts from scratch; on restarts,
        # the Estimator's own checkpoint already has the weights.
        if session.run(tf.train.get_or_create_global_step()) == 0:
            # load weights here
This way, the pre-trained weights are loaded at the first step and are then saved along with the rest of the model in the training checkpoints.