簡體   English   中英

如何使用 Tensorflow 2/Keras 保存和繼續訓練具有多個模型部分的 GAN

[英]How to save and resume training a GAN with multiple model parts with Tensorflow 2/ Keras

我目前正在嘗試添加一個功能來中斷和恢復基於此示例代碼創建的 GAN 的訓練: https : //machinelearningmastery.com/how-to-develop-an-auxiliary-classifier-gan-ac-gan-from-從頭開始/

我設法讓它以一種方式工作,我將整個復合 GAN 的權重保存在 summarise_performance 函數中,該函數每 10 個時期觸發一次,如下所示:

# save all weights
filename3 = 'weights_%08d.h5' % (step+1)
gan_model.save_weights(filename3)
print('>Saved: %s and %s and %s' % (filename1, filename2, filename3))

它加載到我添加到程序開頭的一個名為 load_model 的函數中,該函數采用正常構建的 gan 架構,但將其權重更新為最新值,如下所示:

#load model from file and return startBatch number
def load_model(gan_model):
   start_batch = 0
   files = glob.glob("./weights_0*.h5")
   if(len(files) > 0 ):
       most_recent_file = files[len(files)-1]
       gan_model.load_weights(most_recent_file)
       #TODO: breaks if using more than 8 digits for batches
       startBatch = int(most_recent_file[10:18])
       if (start_batch != 0):
           print("> found existing weights; starting at batch %d" % start_batch)
   return start_batch

其中 start_batch 被傳遞給 train 函數以跳過已經完成的時期。

雖然這種減輕重量的方法確實“有效”,但我仍然認為我的方法是錯誤的,因為我發現權重數據顯然不包括 GAN 的優化器狀態,​​因此訓練不會像它那樣繼續沒有被打斷。

我發現保存進度同時保存優化器狀態的方式顯然是通過保存整個模型而不僅僅是權重來完成的

在這里我遇到了一個問題,因為在 GAN 中,我不僅訓練了一個模型,而且有 3 個模型:

  • 生成器模型 g_model
  • 判別器模型 d_model
  • 和復合 GAN 模型 gan_model

這些都是相互聯系和相互依賴的。 如果我采用天真的方法並分別保存和恢復這些零件模型中的每一個,我最終會得到 3 個獨立的脫節模型而不是 GAN

有沒有一種方法可以讓我恢復訓練,就好像沒有發生中斷一樣,可以保存和恢復整個 GAN?

如果您想恢復整個 GAN,可以考慮使用tf.train.Checkpoint

### In your training loop

checkpoint_dir = '/checkpoints'
checkpoint = tf.train.Checkpoint(gan_optimizer=gan_optimizer,
                            discriminator_optimizer=discriminator_optimizer,
                                  generator=generator,
                                  discriminator=discriminator
                                  gan_model = gan_model)
  
ckpt_manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, max_to_keep=3)
if ckpt_manager.latest_checkpoint:
    checkpoint.restore(ckpt_manager.latest_checkpoint)  
    print ('Latest checkpoint restored!!')

....
....


if (epoch + 1) % 40 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1,ckpt_save_path))

### After x number of epochs, just save your generator model for inference.

generator.save('your_model.h5')

您也可以考慮完全擺脫復合模型。 是我的意思的一個例子。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM