
TensorFlow: using a tf.estimator-trained model inside another tf.estimator model_fn

Is there a way to use a tf.estimator-trained model A inside another model B?

Here is the situation. Let's say I have a trained 'Model A' defined by model_a_fn(). 'Model A' takes images as input and outputs a vector of floating-point values, similar to an MNIST classifier. There is another model, 'Model B', defined in model_b_fn(). It also takes images as input, and it needs the vector output of 'Model A' while 'Model B' is being trained.

So basically I want to train 'Model B', whose inputs are the images plus the prediction output of 'Model A'. (There is no need to train 'Model A' any further; I only need its predictions while training 'Model B'.)

I've tried three approaches:

  1. Use the estimator object ('Model A') inside model_b_fn().
  2. Export 'Model A' with Estimator.export_savedmodel(), create a prediction function from it, and pass that function to model_b_fn() through the params dict.
  3. Same as 2, but restore 'Model A' inside model_b_fn().

But each approach raises an error:

  1. ... must be from the same graph as ...
  2. TypeError: can't pickle _thread.RLock objects
  3. TypeError: The value of a feed cannot be a tf.Tensor object.
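
The second error can be reproduced without TensorFlow. A likely cause, offered here as an assumption rather than a confirmed diagnosis, is that tf.estimator.Estimator deep-copies its params dict, and the predictor returned by from_saved_model() owns a session whose internals include thread locks. The sketch below uses a hypothetical FakePredictor class that stands in for the real predictor object:

```python
import copy
import threading


class FakePredictor:
    """Hypothetical stand-in for an object that owns a thread lock."""

    def __init__(self):
        self._lock = threading.RLock()


def try_deepcopy(params):
    """Return None if params can be deep-copied, else the error message."""
    try:
        copy.deepcopy(params)
        return None
    except TypeError as err:
        return str(err)


# A plain string (such as an export directory) copies without trouble.
path_error = try_deepcopy({'model_a_dir': './model_a/123456789'})

# An object that holds a lock cannot be copied; the message names RLock.
lock_error = try_deepcopy({'model_a_predict_fn': FakePredictor()})
print(path_error, lock_error)
```

If that assumption holds, passing only plain values such as a directory path through params avoids this particular error.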

Here is the code I used (only the important parts are shown).

train_model_a.py

def model_a_fn(features, labels, mode, params):
    # ...
    # ...
    # ...
    return

def main():
    # model checkpoint location
    model_a_dir = './model_a'

    # create estimator for Model A
    model_a = tf.estimator.Estimator(model_fn=model_a_fn, model_dir=model_a_dir)

    # train Model A
    model_a.train(input_fn=lambda : input_fn_a)
    # ...
    # ...
    # ...

    # export model a
    model_a.export_savedmodel(model_a_dir, serving_input_receiver_fn=serving_input_receiver_fn)
    # exported to ./model_a/123456789
    return

if __name__ == '__main__':
    main()

train_model_b_case_1.py

# follows model_a's input format
def bypass_input_fn(x):
    features = {
        'x': x,
    }
    return features

def model_b_fn(features, labels, mode, params):
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a = params['model_a']
    predictions = model_a.predict(
        input_fn=lambda: bypass_input_fn(inputs)
    )
    for results in predictions:
        # Error occurs!!!
        model_a_output = results['class_id']

    # build Model B
    layer1 = tf.layers.conv2d(inputs, 32, 5, padding='same', activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    layern = tf.layers.dense(flatten, 10)

    # assume layern and model_a_output have the same shape
    add_layer = tf.add(layern, model_a_output)

    # ...
    # do more... stuff
    # ...
    return

def main():
    # load pretrained model A
    model_a_dir = './model_a'
    model_a = tf.estimator.Estimator(model_fn=model_a_fn, model_dir=model_a_dir)

    # model checkpoint location
    model_b_dir = './model_b/'

    # create estimator for Model B
    model_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir=model_b_dir,
        params={
            'model_a': model_a,
        }
    )

    # train Model B
    model_b.train(input_fn=lambda : input_fn_b)
    return

if __name__ == '__main__':
    main()

train_model_b_case_2.py

def model_b_fn(features, labels, mode, params):
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a_predict_fn = params['model_a_predict_fn']
    model_a_prediction = model_a_predict_fn(
        {
            'x': inputs
        }
    )
    model_a_output = model_a_prediction['output']

    # build Model B
    layer1 = tf.layers.conv2d(inputs, 32, 5, padding='same', activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    layern = tf.layers.dense(flatten, 10)

    # assume layern and model_a_output have the same shape
    add_layer = tf.add(layern, model_a_output)

    # ...
    # do more... stuff
    # ...
    return

def main():
    # load pretrained model A
    model_a_dir = './model_a/123456789'
    model_a_predict_fn = tf.contrib.predictor.from_saved_model(export_dir=model_a_dir)

    # model checkpoint location
    model_b_dir = './model_b/'

    # create estimator for Model B
    # Error occurs!!!
    model_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir=model_b_dir,
        params={
            'model_a_predict_fn': model_a_predict_fn,
        }
    )

    # train Model B
    model_b.train(input_fn=lambda : input_fn_b)
    return

if __name__ == '__main__':
    main()

train_model_b_case_3.py

def model_b_fn(features, labels, mode, params):
    # parse input
    inputs = tf.reshape(features['x'], shape=[-1, 28, 28, 1])

    # get Model A's response
    model_a_predict_fn = tf.contrib.predictor.from_saved_model(export_dir=params['model_a_dir'])
    # Error occurs!!!
    model_a_prediction = model_a_predict_fn(
        {
            'x': inputs
        }
    )
    model_a_output = model_a_prediction['output']

    # build Model B
    layer1 = tf.layers.conv2d(inputs, 32, 5, padding='same', activation=tf.nn.relu)
    layer1 = tf.layers.max_pooling2d(layer1, pool_size=[2, 2], strides=2)

    # ...
    # some layers added...
    # ...

    flatten = tf.layers.flatten(prev_layer)
    layern = tf.layers.dense(flatten, 10)

    # assume layern and model_a_output have the same shape
    add_layer = tf.add(layern, model_a_output)

    # ...
    # do more... stuff
    # ...
    return

def main():
    # load pretrained model A
    model_a_dir = './model_a/123456789'

    # model checkpoint location
    model_b_dir = './model_b/'

    # create estimator for Model B
    # Error occurs!!!
    model_b = tf.estimator.Estimator(
        model_fn=model_b_fn,
        model_dir=model_b_dir,
        params={
            'model_a_dir': model_a_dir,
        }
    )

    # train Model B
    model_b.train(input_fn=lambda : input_fn_b)
    return

if __name__ == '__main__':
    main()

Any ideas on how to use a trained custom tf.estimator inside another tf.estimator?

I've figured out one solution to this problem.

Anyone struggling with the same problem can use this method:

  1. Create a function that runs tensorflow.contrib.predictor.from_saved_model(); call it 'pretrained_predictor()'.
  2. Inside Model B's model_fn(), call 'pretrained_predictor()' wrapped in tensorflow.py_func().

For a simple working example, see https://github.com/moono/tf-cnn-mnist/blob/master/4_3_estimator_within_estimator.py.
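
The two steps above might look roughly like the following. This is a minimal sketch under TensorFlow 1.x assumptions, not the code from the linked repository: the helper names make_pretrained_predictor and model_a_output_tensor, the load_fn argument (added only so the wrapper can be exercised without TensorFlow), the float32 dtype, and the [None, 10] output shape are all hypothetical. The 'x' and 'output' keys are taken from the question's code.

```python
def make_pretrained_predictor(export_dir, load_fn=None):
    """Build a NumPy-in, NumPy-out function that runs exported Model A.

    load_fn is a hypothetical hook for tests; when it is None the
    SavedModel is loaded with tf.contrib.predictor (TensorFlow 1.x only).
    """
    cache = {}

    def pretrained_predictor(images):
        # Load lazily on the first call, i.e. at session run time rather
        # than while Model B's graph is being built.
        if 'fn' not in cache:
            if load_fn is None:
                from tensorflow.contrib import predictor
                cache['fn'] = predictor.from_saved_model(export_dir)
            else:
                cache['fn'] = load_fn(export_dir)
        # 'x' and 'output' follow the keys used in the question.
        return cache['fn']({'x': images})['output']

    return pretrained_predictor


def model_a_output_tensor(inputs, export_dir):
    """Call inside model_b_fn(): wrap the predictor with tf.py_func()."""
    import tensorflow as tf  # assumes TensorFlow 1.x

    predict = make_pretrained_predictor(export_dir)
    # py_func hands NumPy arrays to predict at run time, so the predictor
    # is fed values rather than tf.Tensor objects.
    output = tf.py_func(predict, [inputs], tf.float32)
    # py_func drops static shape information; restore the assumed shape.
    output.set_shape([None, 10])
    return output
```

Inside model_b_fn(), a call such as model_a_output = model_a_output_tensor(inputs, params['model_a_dir']) would then replace the direct predictor call from case 3. Only the directory string travels through params, and as far as I know no gradient flows back through tf.py_func(), which fits the requirement that 'Model A' is not trained any further.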
