
Transfer learning / retraining with TensorFlow Estimators

I have been unable to figure out how to use transfer learning / last-layer retraining with the new TF Estimator API.

The Estimator requires a model_fn, which contains the architecture of the network and the training and eval ops, as defined in the documentation. An example of a model_fn using a CNN architecture is here.

If I want to retrain the last layer of, for example, the Inception architecture, I'm not sure whether I will need to specify the whole model in this model_fn and then load the pre-trained weights, or whether there is a way to use the saved graph as is done in the 'traditional' approach (example here).

This has been brought up as an issue, but it is still open, and the answers there are unclear to me.

It is possible to load the metagraph during model definition and use a SessionRunHook to load the weights from a ckpt file.

def model_fn(features, labels, mode, params):
    # Build the graph here: import the pretrained metagraph and add
    # the new layer(s) to retrain (see the fleshed-out sketch below).

    return tf.estimator.EstimatorSpec(
        mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        training_hooks=[RestoreHook(saver, params['checkpoint_path'])])
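
One way to fill in the graph-creation step is to import the saved metagraph inside model_fn and put a fresh classification layer on top of its bottleneck tensor. The following is only a minimal sketch, not the original answer's code: the params keys, the tensor names 'input:0' and 'bottleneck:0', the layer name new_logits, and the optimizer choice are illustrative assumptions, and only TRAIN mode is handled.

import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Import the pretrained graph structure, feeding our input feature
    # into its input tensor. Note that import_meta_graph returns a Saver
    # that covers exactly the imported (pretrained) variables.
    saver = tf.train.import_meta_graph(
        params['meta_graph_path'],
        input_map={'input:0': features['x']})  # 'input:0' is an assumed name
    bottleneck = tf.get_default_graph().get_tensor_by_name('bottleneck:0')

    # New, randomly initialized last layer; the only part being retrained.
    logits = tf.layers.dense(bottleneck, params['num_classes'],
                             name='new_logits')
    predictions = tf.argmax(logits, axis=1)

    # Only TRAIN mode is sketched here; EVAL/PREDICT would need their
    # own branches.
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                  logits=logits)
    train_op = tf.train.GradientDescentOptimizer(
        params['learning_rate']).minimize(
            loss,
            global_step=tf.train.get_or_create_global_step(),
            var_list=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                       scope='new_logits'))

    return tf.estimator.EstimatorSpec(
        mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        training_hooks=[RestoreHook(saver, params['checkpoint_path'])])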

The SessionRunHook can be:

class RestoreHook(tf.train.SessionRunHook):
    def __init__(self, saver, checkpoint_path):
        self._saver = saver  # e.g. the Saver returned by tf.train.import_meta_graph
        self._checkpoint_path = checkpoint_path

    def after_create_session(self, session, coord=None):
        # Restore the pretrained weights only on the very first step;
        # on restarts the Estimator restores its own checkpoints instead.
        if session.run(tf.train.get_or_create_global_step()) == 0:
            self._saver.restore(session, self._checkpoint_path)

This way, the weights are loaded on the first step and are then saved in the model checkpoints as training proceeds.
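
For completeness, here is a hypothetical wiring of the pieces; the file names, the feature key 'x', and the random-data input_fn are placeholders for a real export and input pipeline:

import numpy as np
import tensorflow as tf

def train_input_fn():
    # Dummy batch for illustration only; substitute a real input pipeline.
    x = np.random.rand(8, 299, 299, 3).astype(np.float32)
    y = np.random.randint(0, 10, size=8)
    return {'x': tf.constant(x)}, tf.constant(y)

estimator = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='/tmp/retrained_model',
    params={'meta_graph_path': 'inception.meta',   # assumed file names
            'checkpoint_path': 'inception.ckpt',
            'num_classes': 10,
            'learning_rate': 0.01})
estimator.train(input_fn=train_input_fn, steps=100)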
