
How to continue training an object detection model using Tensorflow Object Detection API?

I'm using the Tensorflow Object Detection API to train an object detection model using transfer learning. Specifically, I'm using ssd_mobilenet_v1_fpn_coco from the model zoo, and the sample pipeline provided, having of course replaced the placeholders with actual links to my training and eval tfrecords and labels.

I was able to successfully train a model on my ~5000 images (and their corresponding bounding boxes) using the above pipeline (I'm mainly using Google's ML Engine on TPU, if relevant).

Now I have prepared an additional ~2000 images, and would like to continue training my model with those new images without restarting from scratch (training the initial model took ~6h of TPU time). How can I do that?

You have two options. In both, you need to change the input_path of the train_input_reader to point to your new dataset:

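  1. In the training configuration, set the checkpoint to fine-tune from to the checkpoint of your trained model:
train_config {
    # Checkpoint prefix from your first run, e.g. ".../model.ckpt-<step>"
    fine_tune_checkpoint: "<path_to_your_checkpoint>"
    # Must be a string: "detection" here (not true)
    fine_tune_checkpoint_type: "detection"
    load_all_detection_checkpoint_vars: true
    # ... the rest of train_config stays unchanged
}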
  2. Simply keep using the same configuration (except the train_input_reader) with the same model_dir as your previous model. That way, the API will build the graph, check whether a checkpoint already exists in model_dir and matches the graph, and if so, restore it and continue training from it.
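The second option only works if the trainer actually finds a checkpoint in model_dir. As a rough pre-flight check, the sketch below reads the plain-text `checkpoint` index file that TensorFlow savers write into the model directory and reports the latest checkpoint prefix and its step. It assumes a local directory and the conventional `model.ckpt-<step>` naming; the function names are hypothetical. In a real job, `tf.train.latest_checkpoint(model_dir)` does the lookup for you and also handles `gs://` paths.

```python
import os
import re
from typing import Optional, Tuple

# The saver's index file contains a line such as:  model_checkpoint_path: "model.ckpt-12345"
_LATEST_RE = re.compile(r'^model_checkpoint_path:\s*"([^"]+)"', re.MULTILINE)
# Assumption: checkpoints use the conventional "<prefix>-<global_step>" naming.
_STEP_RE = re.compile(r"-(\d+)$")


def parse_checkpoint_index(index_text: str, model_dir: str) -> Optional[Tuple[str, int]]:
    """Return (checkpoint_prefix, global_step) from the text of a `checkpoint` file."""
    match = _LATEST_RE.search(index_text)
    if match is None:
        return None
    prefix = match.group(1)
    # Relative entries are relative to model_dir; keep absolute and URL-style (gs://) paths.
    if "://" not in prefix and not os.path.isabs(prefix):
        prefix = os.path.join(model_dir, prefix)
    step_match = _STEP_RE.search(prefix)
    step = int(step_match.group(1)) if step_match else -1  # -1: step not encoded in the name
    return prefix, step


def latest_checkpoint(model_dir: str) -> Optional[Tuple[str, int]]:
    """Local disk only: look for the `checkpoint` index file in model_dir."""
    index_path = os.path.join(model_dir, "checkpoint")
    if not os.path.isfile(index_path):
        return None  # nothing to restore: training would start from scratch
    with open(index_path, encoding="utf-8") as f:
        return parse_checkpoint_index(f.read(), model_dir)
```

If `latest_checkpoint` returns `None`, pointing the new job at that model_dir would silently start from scratch, which is exactly what you want to avoid here.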

Edit: fine_tune_checkpoint_type was previously set to true by mistake; it should be "detection" or "classification" in general, and "detection" in this specific case. Thanks to Krish for noticing.
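For the first option, the changes are the three train_config fields (with fine_tune_checkpoint_type as the string "detection") plus the new input_path. If you regenerate configs for each ML Engine job, a small script can apply them. The sketch below is a plain-text rewrite using only the standard library, under loudly stated assumptions: the file follows the layout of the sample pipeline configs (top-level `train_config { ... }` and `train_input_reader { ... }` blocks), the training reader has a single `input_path`, and no braces appear inside string values or comments. The helper names are hypothetical. A sturdier route would be parsing the file into the API's protobuf messages (for example via `object_detection.utils.config_util.get_configs_from_pipeline_file`) instead of using regular expressions.

```python
import re
from typing import Tuple


def _block_span(text: str, name: str) -> Tuple[int, int]:
    """Return (body_start, body_end) for the first `name { ... }` block in text."""
    m = re.search(r"^\s*" + re.escape(name) + r"\s*\{", text, re.MULTILINE)
    if m is None:
        raise ValueError(f"no '{name}' block found")
    depth, i = 1, m.end()
    while depth:  # naive brace matching: assumes no braces in strings or comments
        if i >= len(text):
            raise ValueError(f"unbalanced braces in '{name}' block")
        if text[i] == "{":
            depth += 1
        elif text[i] == "}":
            depth -= 1
        i += 1
    return m.end(), i - 1


def _set_field(body: str, field: str, value: str, insert_if_missing: bool = True) -> str:
    """Overwrite the first `field: ...` line in body, or prepend one if allowed."""
    pattern = re.compile(r"^([ \t]*)" + re.escape(field) + r"\s*:.*$", re.MULTILINE)
    if pattern.search(body):
        # A function replacement avoids backslash-escape surprises in Windows-style paths.
        return pattern.sub(lambda m: f"{m.group(1)}{field}: {value}", body, count=1)
    if not insert_if_missing:
        raise ValueError(f"field '{field}' not found")
    return f"\n  {field}: {value}" + body


def point_config_at_checkpoint(config_text: str, checkpoint_prefix: str, train_record: str) -> str:
    """Hypothetical helper: fine-tune from a previous detection checkpoint and
    point train_input_reader at a new TFRecord file."""
    start, end = _block_span(config_text, "train_config")
    body = config_text[start:end]
    body = _set_field(body, "fine_tune_checkpoint", f'"{checkpoint_prefix}"')
    body = _set_field(body, "fine_tune_checkpoint_type", '"detection"')
    body = _set_field(body, "load_all_detection_checkpoint_vars", "true")
    config_text = config_text[:start] + body + config_text[end:]

    # Recompute the span: the train_config edit may have shifted offsets.
    start, end = _block_span(config_text, "train_input_reader")
    body = _set_field(config_text[start:end], "input_path", f'"{train_record}"', insert_if_missing=False)
    return config_text[:start] + body + config_text[end:]
```

Typical use would be reading your pipeline.config, calling `point_config_at_checkpoint` with your checkpoint prefix and new TFRecord path, and writing the result to a new config file for the next job. The eval_input_reader is deliberately left alone.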

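I haven't retrained an object detection model on a new dataset, but it looks like increasing the number of training steps, train_config.num_steps, in the config file and adding the new images to the tfrecord files should be enough.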
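Raising num_steps matters most when resuming in the same model_dir: the global step is restored along with the weights, and the Estimator-based training scripts treat num_steps as an absolute step limit rather than a number of additional steps, so a value at or below the restored step should end training almost immediately. The helper below (hypothetical name `extended_num_steps`) sizes the new value as the restored step plus a chosen number of extra epochs over the combined dataset. Sizing by epochs is just one heuristic, not something the API prescribes.

```python
import math


def extended_num_steps(restored_step: int, num_examples: int, batch_size: int, extra_epochs: float) -> int:
    """Return a num_steps value that trains for roughly `extra_epochs` more passes
    over `num_examples` images, counted on top of the restored global step."""
    if num_examples <= 0 or batch_size <= 0 or extra_epochs <= 0:
        raise ValueError("num_examples, batch_size and extra_epochs must be positive")
    # A partial final batch still costs a step.
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return restored_step + math.ceil(steps_per_epoch * extra_epochs)
```

For instance, with a hypothetical restored step of 25000, 7000 images in total (the original ~5000 plus the new ~2000) and a batch size of 64, ten more epochs gives `extended_num_steps(25000, 7000, 64, 10) == 26100`.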

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM