
How can I make predictions from a trained model inside a TensorFlow input pipeline?

I am trying to train an emotion recognition model that takes the output of one of VGG's layers as its input.

I can get what I want by running the prediction as a first step, saving the extracted features, and then using them as input to my network, but I am looking for a way to do the whole process at once.

The second model uses a concatenated array of feature maps as input (I am working with video data), so I am not able to simply wire it to the output of VGG.

I tried to use a map operation, as described in the tf.data.Dataset API documentation, like this:

import tensorflow as tf

def trimmed_vgg16():
    # Cut VGG16 off at fc1, the third layer from the end.
    vgg16 = tf.keras.applications.vgg16.VGG16(input_shape=(224, 224, 3))
    trimmed = tf.keras.models.Model(inputs=vgg16.get_input_at(0),
                                    outputs=vgg16.layers[-3].get_output_at(0))
    return trimmed

vgg16_model = trimmed_vgg16()

def _extract_vgg_features(images, labels):
    # This call is what raises the error quoted below.
    pred = vgg16_model.predict(images, batch_size=batch_size, steps=1)
    return pred, labels

dataset = ...  # load the dataset (image, label) as usual
dataset = dataset.map(_extract_vgg_features)

But I get this error, which is pretty explicit: Tensor Tensor("fc1/Relu:0", shape=(?, 4096), dtype=float32) is not an element of this graph. I'm stuck here, as I don't see a good way to insert the trained model into the same graph and get predictions "on the fly".

Is there a clean way of doing this, or something similar?

Edit: missed a line.
Edit2: added details

You should be able to connect the two models directly: create the VGG16 model first, retrieve the output tensor of the layer you need, and then use that tensor as the input to your own network.

import tensorflow as tf

vgg16 = tf.keras.applications.vgg16.VGG16(input_shape=(224, 224, 3))

network_input = vgg16.get_input_at(0)
vgg16_out = vgg16.layers[-3].get_output_at(0)  # fc1 output; use this tensor as input to your own network
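Since your second model consumes features from several frames at once, one possible way to keep everything in a single graph is to wrap the trimmed VGG in tf.keras.layers.TimeDistributed, so that it is applied to every frame of a clip and the per-frame features come out stacked along a time axis. The following is only a minimal sketch, assuming TF 2.x and clips shaped (frames, 224, 224, 3). The helper name build_video_model, its parameters num_frames and num_classes, and the Flatten + Dense head are hypothetical stand-ins for your own network, not something taken from your code.

```python
import tensorflow as tf


def build_video_model(backbone, num_frames, num_classes, layer_index=-3):
    """Chain a trimmed backbone and a small head in one Keras graph.

    All names here are illustrative; the head is a placeholder for
    whatever network actually consumes the features.
    """
    # Trim the backbone at the chosen layer (fc1 for VGG16 with -3).
    extractor = tf.keras.Model(inputs=backbone.input,
                               outputs=backbone.layers[layer_index].output)
    extractor.trainable = False  # leave the pretrained weights untouched

    # One sample is a whole clip: (num_frames, height, width, channels).
    frame_shape = tuple(backbone.input.shape[1:])
    clips = tf.keras.Input(shape=(num_frames,) + frame_shape)

    # Run the extractor on every frame: (batch, num_frames, features).
    features = tf.keras.layers.TimeDistributed(extractor)(clips)

    # Hypothetical head: concatenate the per-frame features and classify.
    x = tf.keras.layers.Flatten()(features)
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(x)
    return tf.keras.Model(inputs=clips, outputs=outputs)
```

With your setup this would be called as build_video_model(tf.keras.applications.vgg16.VGG16(input_shape=(224, 224, 3)), num_frames=8, num_classes=7), where 8 and 7 are placeholder values. The dataset then only has to yield (clip, label) pairs, and the VGG forward pass happens inside the model rather than inside dataset.map.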
