
TensorFlow: can I reuse part of a Keras model as an operation or function?

I was thinking of reusing the lower part of tf.contrib.keras.applications.ResNet50 and feeding its output into my own layers. I did:

import numpy as np
import tensorflow as tf
from functools import partial

# Tell Keras layers (e.g. batch norm) that we are in training mode.
tf.contrib.keras.backend.set_learning_phase(True)

# seq_index and ing_dataset_map are helpers defined elsewhere in my code.
tf_dataset = tf.contrib.data.Dataset.from_tensor_slices(
        np.arange(seq_index.size('velodyne'))) \
        .shuffle(1000) \
        .repeat() \
        .batch(10) \
        .map(partial(ing_dataset_map, seq_index))

iterator = tf_dataset.make_initializable_iterator()                            
next_elements = iterator.get_next()                                            

def model(input_maps):
    input_maps = tf.reshape(input_maps, shape = [-1, 200, 200, 3])
    resnet = tf.contrib.keras.applications.ResNet50(
            include_top = False, weights = None,
            input_shape = (200, 200, 3), pooling = None)

    net = resnet.apply(input_maps)
    # Grab an intermediate ResNet activation by name instead of the final output.
    temp = tf.get_default_graph().get_tensor_by_name('activation_40/Relu:0')

    net = tf.layers.conv2d(inputs = temp,
            filters = 2, kernel_size = [1, 1], padding = 'same',
            activation = tf.nn.relu)
    return net

m = model(next_elements['input_maps'])                                         

with tf.Session() as sess:                                                     
    sess.run(iterator.initializer)                                          
    sess.run(tf.global_variables_initializer())                             

    ret = sess.run(m)                                                       

Then TensorFlow will report:

You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,200,200,3]

If I directly use the output of the whole resnet.apply(input_maps), there are no errors. So I was just wondering how this could be reworked? Thank you.

Found the answer myself. The intermediate tensor fetched by name still belongs to the graph that ResNet50 built on its own input_1 placeholder, not to the graph applied to input_maps, which is why TensorFlow asks for a feed. The fix is to make use of the Model functionality to create a usable sub-graph:

# Collect the intermediate activations we want from the original ResNet graph.
outputs = []
outputs.append(tf.get_default_graph().get_tensor_by_name('activation_25/Relu:0'))
outputs.append(tf.get_default_graph().get_tensor_by_name('activation_31/Relu:0'))
inputs = resnet.input
# Wrap them in a new Model, then apply that sub-model to our own input tensor.
sub_resnet = tf.contrib.keras.models.Model(inputs, outputs)
low_branch, high_branch = sub_resnet.apply(input_maps)
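
Folded back into the model() function from the question, a minimal sketch of the working version could look like the following (layer names such as 'activation_40/Relu:0' depend on the Keras version bundled with TensorFlow, so adjust them to your graph):

def model(input_maps):
    input_maps = tf.reshape(input_maps, shape = [-1, 200, 200, 3])
    resnet = tf.contrib.keras.applications.ResNet50(
            include_top = False, weights = None,
            input_shape = (200, 200, 3), pooling = None)

    # Build a sub-model that ends at the intermediate activation, instead of
    # fetching that tensor by name and wiring later layers to it directly.
    intermediate = tf.get_default_graph().get_tensor_by_name('activation_40/Relu:0')
    sub_resnet = tf.contrib.keras.models.Model(resnet.input, intermediate)

    # Applying the sub-model re-applies its layers to input_maps, so nothing
    # has to be fed for the original 'input_1' placeholder.
    net = sub_resnet.apply(input_maps)

    net = tf.layers.conv2d(inputs = net,
            filters = 2, kernel_size = [1, 1], padding = 'same',
            activation = tf.nn.relu)
    return net

Passing a list of activations as outputs (as in the snippet above it) gives the multi-branch variant with low_branch and high_branch.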
