How to continue training with checkpoints using object_detector.EfficientDetLite4Spec (TensorFlow Lite)
Previously I set "grad_checkpoint=true" for my EfficientDetLite4 model in config.yaml, and it successfully generated some checkpoints. However, I can't figure out how to use these checkpoints when I want to continue training from them.
Every time I train the model it just starts from the beginning, not from my checkpoints.
The following picture shows my Colab file system structure:
<img src="https://i.stack.imgur.com/8EhPx.jpg"/>
my Colab file system structure
The following picture shows where my checkpoints are stored:
<img src="https://i.stack.imgur.com/Ve5al.jpg"/>
model file system here
The following code shows how I configure the model and how I train with it.
import numpy as np
import os
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
import tensorflow as tf
assert tf.__version__.startswith('2')
tf.get_logger().setLevel('ERROR')
from absl import logging
logging.set_verbosity(logging.ERROR)
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('csv_path')
spec_strategy = 'gpus' #assumed: this variable was defined elsewhere in my notebook, matching the hparams below
spec = object_detector.EfficientDetLite4Spec(
    uri='/content/model',
    model_dir='/content/drive/MyDrive/MathSymbolRecognition/CheckPoints/',
    hparams='grad_checkpoint=true,strategy=gpus',
    epochs=50, batch_size=3,
    steps_per_execution=1, moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25, strategy=spec_strategy
)
model = object_detector.create(train_data, model_spec=spec, batch_size=3,
                               train_whole_model=True, validation_data=validation_data)
I ran into the same problem and found out that the model_dir we pass to the TFLite Model Maker's object detector API is only used for saving the model's weights: that's why the API never restores from checkpoints.
Having a look at the source code of this API, I noticed that it internally uses the standard model.compile and model.fit functions, and that it saves the model's weights through the callbacks parameter of model.fit.
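To see why that saves but never restores, here is the generic Keras mechanism in miniature. This is a deliberately tiny stand-in (a toy model and tf.keras.callbacks.ModelCheckpoint, not the exact callback Model Maker builds internally): the callback writes weights after every epoch, but nothing ever loads them back.
import tensorflow as tf

#Toy model and data, purely illustrative -- not EfficientDet
model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
model.compile(optimizer='sgd', loss='mse')
data = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal((32, 4)), tf.random.normal((32, 1)))).batch(8)

#A save-only callback: weights get written at the end of each epoch,
#but restoring them is left entirely to the caller
ckpt_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath='/tmp/ckpts/ckpt-{epoch}', save_weights_only=True)
model.fit(data, epochs=3, callbacks=[ckpt_cb])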
This means that, provided we can get hold of the internal Keras model, we can restore our checkpoints simply by calling model.load_weights!
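Concretely, the restore step boils down to these two calls (a sketch assuming checkpoint_dir is the model_dir you passed to the spec and model is the internal Keras model obtained below):
latest = tf.train.latest_checkpoint(checkpoint_dir) #e.g. '.../ckpt-35'
model.load_weights(latest)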
If you want to know more about what some of the functions I use below do, have a look at the source code of the API.
This is the code:
#Useful imports
import tensorflow as tf
from tflite_model_maker.config import QuantizationConfig
from tflite_model_maker.config import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import object_detector
from tflite_model_maker.object_detector import DataLoader
#Import the same libs that TFLiteModelMaker internally uses
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train
from tensorflow_examples.lite.model_maker.third_party.efficientdet.keras import train_lib
#Setup variables
batch_size = 6 #or whatever batch size you want
epochs = 50
checkpoint_dir = "/content/..." #whatever your checkpoint directory is
#Create whichever object detector's spec you want
spec = object_detector.EfficientDetLite4Spec(
    model_name='efficientdet-lite4',
    uri='https://tfhub.dev/tensorflow/efficientdet/lite4/feature-vector/2',
    hparams='', #enable grad_checkpoint=True if you want
    model_dir=checkpoint_dir,
    epochs=epochs,
    batch_size=batch_size,
    steps_per_execution=1,
    moving_average_decay=0,
    var_freeze_expr='(efficientnet|fpn_cells|resample_p6)',
    tflite_max_detections=25,
    strategy=None,
    tpu=None,
    gcp_project=None,
    tpu_zone=None,
    use_xla=False,
    profile=False,
    debug=False,
    tf_random_seed=111111,
    verbose=1
)
#Load your datasets
train_data, validation_data, test_data = object_detector.DataLoader.from_csv('/path/to/csv.csv')
#Create the object detector
detector = object_detector.create(train_data,
                                  model_spec=spec,
                                  batch_size=batch_size,
                                  train_whole_model=True,
                                  validation_data=validation_data,
                                  epochs=epochs,
                                  do_train=False #don't train yet: we restore the checkpoint first
)
"""
From here on we use internal/"private" functions of the API,
you can tell because the methods's names begin with an underscore
"""
#Convert the datasets for training
train_ds, steps_per_epoch, _ = detector._get_dataset_and_steps(train_data, batch_size, is_training=True)
validation_ds, validation_steps, val_json_file = detector._get_dataset_and_steps(validation_data, batch_size, is_training=False)
#Get the internal Keras model
model = detector.create_model()
#Copy what the API internally does as setup
config = spec.config
config.update(
    dict(
        steps_per_epoch=steps_per_epoch,
        eval_samples=batch_size * validation_steps,
        val_json_file=val_json_file,
        batch_size=batch_size
    )
)
train.setup_model(model, config) #This is the model.compile call basically
model.summary()
"""
Here we restore the weights
"""
#Load the weights from the latest checkpoint.
#In my case:
#checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/"
#specific_checkpoint_dir = "/content/drive/My Drive/Colab Notebooks/checkpoints_heavy/ckpt-35"
completed_epochs = 0 #fall back to starting from epoch 0 if no checkpoint is found
try:
    #Option A:
    #load the weights from the last successfully completed epoch
    latest = tf.train.latest_checkpoint(checkpoint_dir)
    #Option B:
    #load the weights from a specific checkpoint
    #latest = specific_checkpoint_dir
    #Checkpoints are named 'ckpt-<epoch>', so we can recover the epoch
    #the training was at when it was last interrupted
    completed_epochs = int(latest.split("/")[-1].split("-")[1])
    model.load_weights(latest)
    print("Checkpoint found {}".format(latest))
except Exception as e:
    print("Checkpoint not found: ", e)
"""
Optional step.
Add callbacks that get executed at the end of every N epochs:
in this case I want to log the training results to TensorBoard.
"""
#tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=tensorboard_dir, histogram_freq=1)
#callbacks = train_lib.get_callbacks(config.as_dict(), validation_ds)
#callbacks.append(tensorboard_callback)
"""
Train the model
"""
model.fit(
    train_ds,
    epochs=epochs,
    initial_epoch=completed_epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_ds,
    validation_steps=validation_steps,
    callbacks=train_lib.get_callbacks(config.as_dict(), validation_ds) #this saves checkpoints at the end of every epoch
)
"""
Save/export the trained model
Tip: for integer quantization you simply have to NOT SPECIFY
the quantization_config parameter of the detector.export method
"""
export_dir = "/content/..." #save the tflite wherever you want
quant_config = QuantizationConfig.for_float16() #or whatever quantization you want
detector.model = model #inject our trained model into the object detector
detector.export(export_dir = export_dir, tflite_filename='model.tflite', quantization_config = quant_config)
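Following the tip above, an integer-quantized export would simply omit the quantization_config parameter (the filename here is only an example):
detector.export(export_dir=export_dir, tflite_filename='model_int8.tflite') #no quantization_config -> integer quantization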