
Keras functional API: alternative to predict_classes()

Please refer here to my previous question for background information. As per the answer suggested by Nassim Ben, I trained a model with a two-path architecture using the functional API. Now I feel stuck, as I need to predict the class of each pixel. Here is the code for the same:

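    import numpy as np
    from skimage import io
    from sklearn.feature_extraction.image import extract_patches_2d

    # load the image as 5 slices of 240x240 (the last slice is skipped below)
    imgs = io.imread(test_img).astype('float').reshape(5, 240, 240)
    plist = []

    # create patches from an entire slice
    for img in imgs[:-1]:
        if np.max(img) != 0:
            img /= np.max(img)
        p = extract_patches_2d(img, (33, 33))
        plist.append(p)
    # group the 4 slices per patch: shape (n_patches, 4, 33, 33)
    # (np.array(zip(...)) only gives this shape on Python 2)
    patches = np.stack(plist, axis=1)

    # predict classes of each pixel based on model
    full_pred = self.model_comp.predict_classes(patches)
    fp1 = full_pred.reshape(208, 208)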

But according to the GitHub link, predict_classes() is unavailable for functional models. So my question is: is there any other alternative that I can try?

Nassim's answer is great, but I want to share the experience I have with similar tasks:

  1. Never use predict_proba Keras for version. Here you could find why.
  2. Most methods for turning predictions into classes don't take your data statistics into account. In image segmentation, detecting an object is very often more important than detecting the background. For this reason I advise you to use a per-class threshold obtained from a precision-recall curve: for each class, set the threshold to the value at which precision == recall (or is as close as possible). Once you have the thresholds, you need to write your own custom function for class prediction.
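The thresholding idea in point 2 could be sketched roughly as follows. This is only a minimal NumPy illustration, not a definitive implementation: the names `balanced_thresholds`, `predict_with_thresholds` and the `fallback` parameter are hypothetical, `y_true` is assumed to hold integer labels from a held-out validation set, and `y_prob` is assumed to be the `(n_samples, n_classes)` output of `model.predict`. How to resolve samples where several classes (or none) pass their threshold is a design choice the answer leaves open; here the largest margin wins and a fallback class (e.g. background) is used when nothing passes.

```python
import numpy as np

def balanced_thresholds(y_true, y_prob):
    """For each class, pick the score threshold where precision is closest to recall.

    y_true: (n_samples,) integer labels; y_prob: (n_samples, n_classes) scores.
    Assumes every class occurs at least once in y_true.
    """
    n_classes = y_prob.shape[1]
    thresholds = np.empty(n_classes)
    for c in range(n_classes):
        is_class = (y_true == c)
        scores = y_prob[:, c]
        best_thr, best_gap = None, np.inf
        # every distinct score is a candidate threshold (one-vs-rest)
        for thr in np.unique(scores):
            predicted = scores >= thr
            tp = np.sum(predicted & is_class)
            precision = tp / predicted.sum()  # >= 1 prediction, thr is a score
            recall = tp / is_class.sum()
            gap = abs(precision - recall)
            if gap < best_gap:
                best_thr, best_gap = thr, gap
        thresholds[c] = best_thr
    return thresholds

def predict_with_thresholds(y_prob, thresholds, fallback=0):
    """Assign the class with the largest margin over its own threshold.

    Samples where no class reaches its threshold get the fallback class.
    """
    margin = y_prob - thresholds
    classes = np.argmax(margin, axis=1)
    classes[margin.max(axis=1) < 0] = fallback
    return classes
```

The loop over candidate thresholds is quadratic in the worst case, so for a full slice of pixels it is worth computing the same curve with `sklearn.metrics.precision_recall_curve` instead; the selection step (minimising the gap between precision and recall) stays the same.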

Indeed, predict_classes is not available for functional models, as it might not make sense to use it in some cases. However, a "one-liner" solution exists for this:

y_classes = keras.utils.np_utils.probas_to_classes(self.model_comp.predict(patches))

This works in Keras 1.2.2; I am not sure about Keras 2.0, as I couldn't find the function in the source code. But there is really nothing shady about this: your model outputs a vector of probabilities of belonging to each class, and the function just takes the argmax and outputs the class corresponding to the highest probability.
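If the helper is missing in your Keras version, the argmax step described above can be done directly with NumPy. A minimal sketch, assuming the model ends in a softmax so that predict returns an array of shape `(n_samples, n_classes)`; the name `to_classes` is hypothetical and the small array stands in for `self.model_comp.predict(patches)`:

```python
import numpy as np

def to_classes(y_prob):
    """Return, for each row of class probabilities, the index of the largest one."""
    return np.argmax(y_prob, axis=-1)

# stand-in for self.model_comp.predict(patches): 2 samples, 3 classes
y_prob = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1]])
y_classes = to_classes(y_prob)
print(y_classes)  # [1 0]
```

With the patches from the question, the result would then be reshaped as before, e.g. `to_classes(self.model_comp.predict(patches)).reshape(208, 208)`.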

I hope this helps.
