
AttributeError: 'Functional' object has no attribute 'predict_segmentation' when loading a saved keras_segmentation model with TensorFlow Keras

I have successfully trained a Keras model like this:

import tensorflow as tf
from keras_segmentation.models.unet import vgg_unet

# initialize the model
model = vgg_unet(n_classes=50, input_height=512, input_width=608)

# Train (train_images and train_annotations are the image and annotation folder paths)
model.train(
    train_images=train_images,
    train_annotations=train_annotations,
    checkpoints_path="/tmp/vgg_unet_1", epochs=5
)

I saved it in HDF5 format with:

tf.keras.models.save_model(model,'my_model.hdf5')

Then I load it with:

model = tf.keras.models.load_model('my_model.hdf5')

Finally, I want to run a segmentation prediction on a new image with:

out = model.predict_segmentation(
    inp=image_to_test,
    out_fname="/tmp/out.png"
)

I am getting the following error:

AttributeError: 'Functional' object has no attribute 'predict_segmentation'

What am I doing wrong? Is it when I am saving my model or when I am loading it?

Thanks!

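predict_segmentation isn't a method of standard Keras models. The keras_segmentation library attaches it (along with train() and a few attributes) to the model object when vgg_unet() builds the model. Saving to HDF5 stores the architecture, the weights and the optimizer state, but not Python methods attached to the instance afterwards, so the object returned by tf.keras.models.load_model is a plain Functional model without predict_segmentation. In other words, the method is already lost at save time; load_model simply has nothing to restore.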
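The mechanism can be reproduced without Keras. The sketch below is a toy, library-free illustration: ToyModel, build_model and predict_segmentation are hypothetical stand-ins (not keras_segmentation's real code), and a JSON string of the constructor config plays the role of the HDF5 file. A method bound to the instance with types.MethodType works on the original object but does not survive the save/load round trip.

```python
import json
from types import MethodType


class ToyModel:
    """Hypothetical stand-in for a Keras Functional model: only the
    constructor config is serialized, like the architecture in an HDF5 file."""

    def __init__(self, n_classes):
        self.n_classes = n_classes

    def get_config(self):
        return {"n_classes": self.n_classes}

    @classmethod
    def from_config(cls, config):
        return cls(**config)


def predict_segmentation(model, inp):
    # Hypothetical helper: label every input element with class 0.
    return [0 for _ in inp]


def build_model(n_classes):
    # Like a library model builder, attach an extra method to the instance.
    model = ToyModel(n_classes)
    model.predict_segmentation = MethodType(predict_segmentation, model)
    return model


def save(model):
    # Only the config is written out; instance-bound methods are not.
    return json.dumps(model.get_config())


def load(blob):
    return ToyModel.from_config(json.loads(blob))


model = build_model(n_classes=50)
print(model.predict_segmentation([1, 2, 3]))  # [0, 0, 0]

reloaded = load(save(model))
print(hasattr(reloaded, "predict_segmentation"))  # False
```

The reloaded object has the same class and config, yet calling reloaded.predict_segmentation(...) would raise the same kind of AttributeError as in the question.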

I think you have two options here.

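  1. Re-attach the method to the loaded model by hand, the same way keras_segmentation does when it builds the model:
from types import MethodType
from keras_segmentation.predict import predict

# bind keras_segmentation's predict() to this model instance
model.predict_segmentation = MethodType(predict, model)
     Note that predict() also reads attributes such as model.n_classes, model.input_height and model.output_height, which are likewise set at build time and lost on reload, so they may need to be restored as well.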
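  2. Create a new vgg_unet with the same arguments when you reload the model, and load the weights from your HDF5 file into it, as suggested in the Keras documentation. Because the model is rebuilt through vgg_unet(), every method and attribute the library attaches is recreated:
from keras_segmentation.models.unet import vgg_unet

# rebuild the architecture, then copy the trained weights into it
model = vgg_unet(n_classes=50, input_height=512, input_width=608)
model.load_weights('my_model.hdf5')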
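Both options can be compared on a toy model. This is a minimal, library-free sketch under stated assumptions: ToyModel, build_model, predict_segmentation and the list-of-floats "weights" are all hypothetical stand-ins, and a JSON string replaces the HDF5 file. Option 1 re-binds the function onto the bare reloaded object; option 2 rebuilds through the same factory and then loads the saved weights.

```python
import json
from types import MethodType


class ToyModel:
    """Hypothetical stand-in for a Keras model: holds a config and weights."""

    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.weights = [0.0] * n_classes

    def get_config(self):
        return {"n_classes": self.n_classes}

    def load_weights(self, weights):
        self.weights = list(weights)


def predict_segmentation(model, inp):
    # Hypothetical helper: label each input element with the index of the largest weight.
    best = max(range(model.n_classes), key=lambda i: model.weights[i])
    return [best for _ in inp]


def build_model(n_classes):
    # Factory that, like vgg_unet(), attaches the extra method at build time.
    model = ToyModel(n_classes)
    model.predict_segmentation = MethodType(predict_segmentation, model)
    return model


# "Train" and save: only the config and weights are serialized, not the bound method.
trained = build_model(n_classes=3)
trained.load_weights([0.1, 0.9, 0.2])
saved = json.dumps({"config": trained.get_config(), "weights": trained.weights})

# A plain reload gives back a bare ToyModel without predict_segmentation.
data = json.loads(saved)
plain = ToyModel(**data["config"])
plain.load_weights(data["weights"])

# Option 1: re-attach the method to the reloaded object by hand.
plain.predict_segmentation = MethodType(predict_segmentation, plain)

# Option 2: rebuild through the factory with the same arguments, then load weights.
rebuilt = build_model(**data["config"])
rebuilt.load_weights(data["weights"])

print(plain.predict_segmentation([7, 8]))    # [1, 1]
print(rebuilt.predict_segmentation([7, 8]))  # [1, 1]
```

Option 2 tends to be the more robust choice, since the factory recreates everything it attaches, while option 1 only restores what you remember to re-bind.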

The technical posts on this site are published under the CC BY-SA 4.0 license. If you reprint this content, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 
Guangdong ICP registration No. 18138465  © 2020-2024 STACKOOM.COM