
How can I save a checkpoint every epoch and load an arbitrary saved checkpoint to continue training?

Can you help me with my code to save the model (architecture and weights) every epoch? I would also like to know how to continue training from the 5th checkpoint I have saved, for example running epochs 1 to 25 from that model, without making new checkpoints.

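from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from keras.optimizers import Adam

classifier = Sequential()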

classifier.add(Conv2D(6, (3, 3), input_shape = (30, 30, 3), data_format="channels_last", activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))

classifier.add(Conv2D(6, (3, 3), activation = 'relu'))
classifier.add(MaxPooling2D(pool_size = (2, 2)))

classifier.add(Flatten())

classifier.add(Dense(units = 128, activation = 'relu'))
classifier.add(Dense(units = 64, activation = 'relu'))
classifier.add(Dense(units = 1, activation = 'sigmoid'))

opt = Adam(learning_rate = 0.001, beta_1 = 0.9, beta_2 = 0.999, epsilon = 1e-08, decay = 0.0)
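# precision, recall and fmeasure are not defined in this snippet; they are assumed to be metric functions defined elsewhere
classifier.compile(optimizer = opt, loss = 'binary_crossentropy', metrics = ['accuracy', precision, recall, fmeasure])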

from keras.preprocessing.image import ImageDataGenerator

train_datagen = ImageDataGenerator(rescale = 1./255,
                                   horizontal_flip = True,
                                   vertical_flip = True,
                                   rotation_range = 180)

validation_datagen = ImageDataGenerator(rescale = 1./255)

training_set = train_datagen.flow_from_directory('/home/dataset/training_set',
                                                 target_size = (30, 30),
                                                 batch_size = 32,
                                                 class_mode = 'binary')

validation_set = validation_datagen.flow_from_directory('/home/dataset/validation_set',
                                                        target_size = (30, 30),
                                                        batch_size = 32,
                                                        class_mode = 'binary')

history = classifier.fit_generator(training_set,
                                   steps_per_epoch = 208170,
                                   epochs = 15,
                                   validation_data = validation_set,
                                   validation_steps = 89140)

I'm assuming you want to save your model and weights after every epoch and then, at a later stage, load the model and weights saved after the fifth epoch.

In general, you can save a model in TensorFlow's SavedModel format like this:

classifier.save(path_to_save_to)

This saves the architecture, the weights, the optimizer state, and the configuration you set up in compile().

Since you're using fit_generator, you can pass a ModelCheckpoint() callback to save the model after every epoch, like this:

from keras.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(path_to_save_to, save_freq = 'epoch', 
                             save_weights_only = False)

history = classifier.fit_generator(training_set,
                                   steps_per_epoch = 208170,
                                   epochs = 15,
                                   validation_data = validation_set,
                                   validation_steps = 89140,
                                   callbacks = [checkpoint])

You can format the path so that each file name includes the epoch and loss, for example path_name + '-{epoch:02d}-{val_loss:.2f}.h5'. Without placeholders like these, every epoch overwrites the same file, so there would be no separate fifth checkpoint to go back to.
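To see what file names such a template produces, here is a small sketch. ModelCheckpoint fills in the placeholders with str.format, using the 1-based epoch number and the logged metrics. The prefix 'checkpoints/classifier', the helper checkpoint_name and the metric values are made up for illustration.

```python
# Hypothetical prefix; any writable location would do.
path_name = 'checkpoints/classifier'
template = path_name + '-{epoch:02d}-{val_loss:.2f}.h5'


def checkpoint_name(epoch, logs):
    """Expand the template the way ModelCheckpoint does (epoch is 1-based)."""
    # Extra keys in logs are ignored by str.format.
    return template.format(epoch=epoch, **logs)


# Made-up metric values, only to show the formatting.
print(checkpoint_name(5, {'loss': 0.30, 'val_loss': 0.3421}))
# prints checkpoints/classifier-05-0.34.h5
```

Note that {val_loss} is only available when validation data is passed to the training call, as it is in the question.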

To load the fifth checkpoint, do this:

from keras.models import load_model
classifier = load_model(path_to_fifth_checkpoint)
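Loading restores the weights and optimizer state, but the training call still has to be told which epoch to resume from. That is done with the initial_epoch argument. Below is a minimal sketch, assuming the checkpoint file names follow the '-{epoch:02d}-{val_loss:.2f}.h5' pattern suggested above. epoch_from_checkpoint and resume_training are hypothetical helpers, not part of Keras. If precision, recall and fmeasure are your own functions, load_model will probably need them passed in through custom_objects.

```python
import re


def epoch_from_checkpoint(path):
    """Return the epoch number embedded in a name like 'model-05-0.34.h5'."""
    match = re.search(r'-(\d+)-\d+\.\d+\.h5$', path)
    if match is None:
        raise ValueError('no epoch number found in ' + repr(path))
    return int(match.group(1))


def resume_training(path, training_set, validation_set, total_epochs,
                    custom_objects=None, **fit_kwargs):
    """Load a checkpoint and keep training until total_epochs is reached."""
    # Imported here so the file-name helper above works without Keras.
    from keras.models import load_model

    model = load_model(path, custom_objects=custom_objects)
    history = model.fit(training_set,
                        validation_data=validation_set,
                        # Checkpoint 05 was written after the 5th epoch, so
                        # initial_epoch=5 makes training restart at epoch 6.
                        initial_epoch=epoch_from_checkpoint(path),
                        epochs=total_epochs,
                        **fit_kwargs)
    return model, history
```

Called as resume_training(path_to_fifth_checkpoint, training_set, validation_set, total_epochs=25), this would run epochs 6 to 25. Leaving out callbacks=[checkpoint] means no new checkpoints are written. In older Keras versions, fit_generator accepts the same initial_epoch argument.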
