
How to sequentially combine 2 TensorFlow models?

I have two TensorFlow models, both with the same architecture (Unet-3d). My current flow is:

Pre-processing -> Prediction from Model 1 -> Some operations -> Prediction from Model 2 -> Post-processing

The operations in between can be done in TF. Can both models, together with the operations in between, be combined into a single TF graph so that the flow looks like this?

Pre-processing -> Model 1+2 -> Post-processing

Thanks.

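You can use the tf.keras functional API to achieve this. The idea is to apply the intermediate operations to model_1's symbolic output, call model_2 on the result, and build a new model that runs from model_1's input to that final output. Here is a toy example: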

import tensorflow as tf
print('TensorFlow:', tf.__version__)

def preprocessing(tensor):
    # perform your operations
    return tensor

def some_operations(model_1_prediction):
    # perform your operations
    # assuming your operations result in a tensor
    # whose shape matches model_2's input
    tensor = model_1_prediction
    return tensor

def post_processing(tensor):
    # perform your operations
    return tensor

def get_model(name):
    inp = tf.keras.Input(shape=[256, 256, 3])
    x = tf.keras.layers.Conv2D(64, 3, 1, 'same')(inp)
    x = tf.keras.layers.Conv2D(256, 3, 1, 'same')(x)
    x = tf.keras.layers.Conv2D(512, 3, 1, 'same')(x)
    x = tf.keras.layers.Conv2D(64, 3, 1, 'same')(x)
    # The last Conv2D uses 3 filters so that model_1's output
    # shape matches model_2's input shape.
    x = tf.keras.layers.Conv2D(3, 3, 1, 'same')(x)
    output = tf.keras.layers.Activation('sigmoid')(x)
    return tf.keras.Model(inp, output, name=name)

model_1 = get_model('model-1')
model_2 = get_model('model-2')


x = some_operations(model_1.output)
out = model_2(x)
model_1_2 = tf.keras.Model(model_1.input, out, name='model-1+2')

model_1_2.summary()

Output:

TensorFlow: 2.1.0-rc0
Model: "model-1+2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 256, 256, 64)      1792      
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 256, 256, 256)     147712    
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 256, 256, 512)     1180160   
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 256, 256, 64)      294976    
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 256, 256, 3)       1731      
_________________________________________________________________
activation (Activation)      (None, 256, 256, 3)       0         
_________________________________________________________________
model-2 (Model)              (None, 256, 256, 3)       1626371   
=================================================================
Total params: 3,252,742
Trainable params: 3,252,742
Non-trainable params: 0
_________________________________________________________________
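
If the operations in between are more than an identity, one option is to wrap them in a tf.keras.layers.Lambda layer so they show up as an explicit layer in the combined model. The sketch below is only an illustration under assumptions: the rescaling inside some_operations is a made-up placeholder for your real intermediate step, get_small_model is a hypothetical tiny stand-in for the actual networks, and the model names are arbitrary. It also compares the combined model against the original two-step pipeline on random data.

```python
import numpy as np
import tensorflow as tf

def some_operations(tensor):
    # Hypothetical intermediate step (an assumption, not from the question):
    # rescale model 1's sigmoid output from [0, 1] to [-1, 1].
    return tensor * 2.0 - 1.0

def get_small_model(name):
    # Tiny stand-in for the real networks, kept small so it runs quickly.
    inp = tf.keras.Input(shape=[32, 32, 3])
    x = tf.keras.layers.Conv2D(8, 3, padding='same', activation='relu')(inp)
    out = tf.keras.layers.Conv2D(3, 3, padding='same', activation='sigmoid')(x)
    return tf.keras.Model(inp, out, name=name)

model_1 = get_small_model('model_1')
model_2 = get_small_model('model_2')

# Wrap the intermediate ops in a Lambda layer so they become a layer
# of the combined model, then call model_2 on the result.
x = tf.keras.layers.Lambda(some_operations, name='some_operations')(model_1.output)
out = model_2(x)
model_1_2 = tf.keras.Model(model_1.input, out, name='model_1_2')

# Compare the combined model with the original two-step pipeline.
batch = np.random.rand(2, 32, 32, 3).astype('float32')
two_step = model_2(some_operations(model_1(batch))).numpy()
combined = model_1_2(batch).numpy()
print('Outputs match:', np.allclose(two_step, combined, atol=1e-5))
```

The combined model shares its weights with model_1 and model_2, so you can load trained weights into the two models first and then build model_1_2 from them. Pre-processing and post-processing stay outside the combined model here, as in the flow described in the question.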


 