简体   繁体   中英

How do you add new categories and training to a pretrained Inception v3 model in TensorFlow?

I'm trying to utilize a pre-trained model like Inception v3 (trained on the 2012 ImageNet data set) and expand it in several missing categories.

I have TensorFlow built from source with CUDA on Ubuntu 14.04, and the examples like transfer learning on flowers are working great. However, the flowers example strips away the final layer and removes all 1,000 existing categories, which means it can now identify 5 species of flowers, but can no longer identify pandas, for example. https://www.tensorflow.org/versions/r0.8/how_tos/image_retraining/index.html

How can I add the 5 flower categories to the existing 1,000 categories from ImageNet (and add training for those 5 new flower categories) so that I have 1,005 categories that a test image can be classified as? In other words, be able to identify both those pandas and sunflowers?

I understand one option would be to download the entire ImageNet training set and the flowers example set and to train from scratch, but given my current computing power, it would take a very long time, and wouldn't allow me to add, say, 100 more categories down the line.

One idea I had was to set the parameter fine_tune to false when retraining with the 5 flower categories so that the final layer is not stripped: https://github.com/tensorflow/models/blob/master/inception/README.md#how-to-retrain-a-trained-model-on-the-flowers-data , but I'm not sure how to proceed, and not sure if that would even result in a valid model with 1,005 categories. Thanks for your thoughts.

Unfortunately, you cannot add categories to an existing graph; you'll basically have to save a checkpoint and train that graph from that checkpoint onward.

After much learning and working in deep learning professionally for a few years now, here is a more complete answer:

The best way to add categories to an existing models (eg Inception trained on the Imagenet LSVRC 1000-class dataset) would be to perform transfer learning on a pre-trained model.

If you are just trying to adapt the model to your own data set (eg 100 different kinds of automobiles), simply perform retraining/fine tuning by following the myriad online tutorials for transfer learning, including the official one for Tensorflow.

While the resulting model can potentially have good performance, please keep in mind that the tutorial classifier code is highly un-optimized (perhaps intentionally) and you can increase performance by several times by deploying to production or just improving their code.

However, if you're trying to build a general purpose classifier that includes the default LSVRC data set (1000 categories of everyday images) and expand that to include your own additional categories, you'll need to have access to the existing 1000 LSVRC images and append your own data set to that set. You can download the Imagenet dataset online, but access is getting spotier as time rolls on. In many cases, the images are also highly outdated (check out the images for computers or phones for a trip down memory lane).

Once you have that LSVRC dataset, perform transfer learning as above but including the 1000 default categories along with your own images. For your own images, a minimum of 100 appropriate images per category is generally recommended (the more the better), and you can get better results if you enable distortions (but this will dramatically increase retraining time, especially if you don't have a GPU enabled as the bottleneck files cannot be reused for each distortion; personally I think this is pretty lame and there's no reason why distortions couldn't also be cached as a bottleneck file, but that's a different discussion and can be added to your code manually).

Using these methods and incorporating error analysis, we've trained general purpose classifiers on 4000+ categories to state-of-the-art accuracy and deployed them on tens of millions of images. We've since moved on to proprietary model design to overcome existing model limitations, but transfer learning is a highly legitimate way to get good results and has even made its way to natural language processing via BERT and other designs.

Hopefully, this helps.

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM