
How to continue training with checkpoints using object_detector.EfficientDetLite4Spec in TensorFlow Lite

Previously, I set "grad_checkpoint=true" for my EfficientDetLite4 model in config.yaml, and training successfully generated some checkpoints. However, I can't figure out how to use these checkpoints when I want to continue training from them.

Every time I train the model, it just starts from the beginning, not from my checkpoints.

The following picture shows my Colab file system structure:

 <img src="https://i.stack.imgur.com/8EhPx.jpg"/>

My Colab file system structure

The following picture shows where my checkpoints are stored:

 <img src="https://i.stack.imgur.com/Ve5al.jpg"/>

Model file system

The following code shows how I configure the model and how I train with it.

import numpy as np
import os

from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector

import tensorflow as tf
assert tf.__version__.startswith('2')

tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)

train_data, validation_data, test_data = \
    object_detector.DataLoader.from_csv('csv_path')

spec = object_detector.EfficientDetLite4Spec(
    uri='/content/model',
    model_dir='/content/drive/MyDrive/MathSymbolRecognition/CheckPoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=50, batch_size=3,
    steps_per_execution=1, moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25,
    strategy=spec_strategy  # spec_strategy must be defined earlier; its definition is not shown in this post
)

model = object_detector.create(train_data, model_spec=spec, batch_size=3, 
    train_whole_model=True, validation_data=validation_data)

The source code is the answer!

I ran into the same problem and found out that the model_dir we pass to the TFLite Model Maker's object detector API is only used for saving the model's weights: that's why the API never restores from checkpoints.

Having a look at the source code of this API, I noticed that it internally uses the standard model.compile and model.fit functions, and that it saves the model's weights through the callbacks parameter of model.fit.
This means that, provided we can get the internal Keras model, we can restore our checkpoints by using model.load_weights!
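Before restoring anything, it can help to confirm what the callback actually wrote to model_dir. The following is a minimal sketch in plain Python, not part of the Model Maker API. It assumes the weights are saved as TensorFlow-format checkpoints named "ckpt-<epoch>", each of which leaves a "ckpt-<epoch>.index" file on disk; the helper names list_checkpoints and latest_checkpoint are hypothetical.

```python
import os
import re
import tempfile

# Assumption: each saved epoch leaves a "ckpt-<epoch>.index" file in model_dir.
_INDEX_RE = re.compile(r"^ckpt-(\d+)\.index$")


def list_checkpoints(checkpoint_dir):
    """Return (epoch, prefix) pairs sorted by epoch number."""
    found = []
    for name in os.listdir(checkpoint_dir):
        match = _INDEX_RE.match(name)
        if match:
            # The prefix (file name without ".index") is what load_weights expects.
            prefix = os.path.join(checkpoint_dir, name[: -len(".index")])
            found.append((int(match.group(1)), prefix))
    # Sorting on the integer epoch keeps ckpt-10 after ckpt-2.
    return sorted(found)


def latest_checkpoint(checkpoint_dir):
    """Return the (epoch, prefix) pair with the highest epoch, or None."""
    checkpoints = list_checkpoints(checkpoint_dir)
    return checkpoints[-1] if checkpoints else None


if __name__ == "__main__":
    # Demo on a throwaway directory filled with empty stand-in files.
    with tempfile.TemporaryDirectory() as tmp:
        for name in ("checkpoint", "ckpt-2.index", "ckpt-10.index"):
            open(os.path.join(tmp, name), "w").close()
        print(latest_checkpoint(tmp))
```

If latest_checkpoint returns None for your directory, there is nothing for model.load_weights to restore, and training will necessarily start from epoch 0.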

These are the links to the source code, if you want to know more about what some of the functions I use below do:

This is the code:

#Useful imports
import tensorflow as tf
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
from tflite_model_maker.object_detector import DataLoader

#Import the same libs that TFLiteModelMaker internally uses
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib



#Setup variables
batch_size = 6 #or whatever batch size you want
epochs = 50
checkpoint_dir = "/content/..." #whatever your checkpoint directory is



#Create whichever object detector's spec you want
spec = object_detector.EfficientDetLite4Spec(
    model_name='efficientdet-lite4',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite4/feature-vector/2', 
    hparams='', #enable grad_checkpoint=True if you want
    model_dir=checkpoint_dir, 
    epochs=epochs, 
    batch_size=batch_size,
    steps_per_execution=1, 
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25, 
    strategy=None, 
    tpu=None, 
    gcp_project=None,
    tpu_zone=None, 
    use_xla=False, 
    profile=False, 
    debug=False, 
    tf_random_seed=111111,
    verbose=1
)



#Load your datasets
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('/path/to/csv.csv')




#Create the object detector 
detector = object_detector.create(train_data, 
                                model_spec=spec, 
                                batch_size=batch_size, 
                                train_whole_model=True, 
                                validation_data=validation_data,
                                epochs = epochs,
                                do_train = False
                                )



"""
From here on we use internal/"private" functions of the API,
you can tell because the method names begin with an underscore
"""

#Convert the datasets for training
train_ds, steps_per_epoch, _ = detector._get_dataset_and_steps(train_data, batch_size, is_training = True)
validation_ds, validation_steps, val_json_file = detector._get_dataset_and_steps(validation_data, batch_size, is_training = False)




#Get the internal keras model
model = detector.create_model()




#Copy what the API internally does as setup
config = spec.config
config.update(
    dict(
        steps_per_epoch=steps_per_epoch,
        eval_samples=batch_size * validation_steps,
        val_json_file=val_json_file,
        batch_size=batch_size
    )
)
train.setup_model(model, config) #This is the model.compile call basically
model.summary()




"""
Here we restore the weights
"""

#Load the weights from the latest checkpoint.
#In my case:
#checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/"
#specific_checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"
completed_epochs = 0 #fallback: train from epoch 0 if no checkpoint can be restored
try:
  #Option A:
  #load the weights from the last successfully completed epoch
  latest = tf.train.latest_checkpoint(checkpoint_dir) #returns None if no checkpoint exists

  #Option B:
  #load the weights from a specific checkpoint
  #latest = specific_checkpoint_dir

  model.load_weights(latest)
  completed_epochs = int(latest.split("/")[-1].split("-")[1]) #the epoch the training was at when it was last interrupted

  print("Checkpoint found {}".format(latest))
except Exception as e:
  print("Checkpoint not found: ", e)




"""
Optional step.
Add callbacks that get executed at the end of every N 
epochs: in this case I want to log the training results to tensorboard.
"""
#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=1)
#callbacks = train_lib.get_callbacks(config.as_dict(), validation_ds)
#callbacks.append(tensorboard_callback)




"""
Train the model 
"""
model.fit(
    train_ds,
    epochs=epochs,
    initial_epoch=completed_epochs, 
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_ds,
    validation_steps=validation_steps,
    callbacks=train_lib.get_callbacks(config.as_dict(), validation_ds) #This is for saving checkpoints at the end of every epoch
)




"""
Save/export the trained model
Tip: for integer quantization you simply have to NOT SPECIFY 
the quantization_config parameter of the detector.export method
"""
export_dir = "/content/..." #save the tflite wherever you want
quant_config = QuantizationConfig.for_float16() #or whatever quantization you want
detector.model = model #inject our trained model into the object detector
detector.export(export_dir = export_dir, tflite_filename='model.tflite', quantization_config = quant_config)
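
The epoch parsing used above works for a path ending in "ckpt-35", but it breaks if the checkpoint prefix itself contains another hyphen. Here is a slightly more defensive sketch in plain Python; the helper names epoch_from_checkpoint and epochs_left are hypothetical and not part of any library. It also spells out how Keras treats the two arguments: in model.fit, epochs is the index of the final epoch, not the number of additional epochs, so resuming with initial_epoch=35 and epochs=50 trains for 15 more epochs.

```python
import re


def epoch_from_checkpoint(path):
    """Extract the trailing epoch number from a prefix such as '.../ckpt-35'."""
    match = re.search(r"-(\d+)$", path.rstrip("/"))
    if match is None:
        raise ValueError(f"no epoch number in checkpoint path: {path!r}")
    return int(match.group(1))


def epochs_left(total_epochs, completed_epochs):
    """Epochs that fit(epochs=total_epochs, initial_epoch=completed_epochs) still runs."""
    return max(total_epochs - completed_epochs, 0)


if __name__ == "__main__":
    completed = epoch_from_checkpoint("/content/checkpoints/ckpt-35")  # hypothetical path
    print(completed, epochs_left(50, completed))
```

If epochs_left returns 0, model.fit has nothing left to do; increase epochs in that case rather than expecting it to train beyond the original schedule.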
