How to retrain with weights in Keras

I was training a model in Colab, but I shut down my computer and the training stopped. I save the weights every 5 epochs. I think resuming is possible, but I don't know how. How can I continue training from the previously saved weights?

Thanks.

When you train a model in Colab, training doesn't stop the moment you close your computer; it stops some time afterwards.

If you save the weights on the Colab machine itself, everything is deleted when the Colab runtime shuts down.

If you have mounted your Google Drive in Colab and saved the weights to Drive, your weights will still be there.
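Mounting Drive is done from a notebook cell with `drive.mount` from the `google.colab` package. A minimal sketch, assuming you want a dedicated checkpoint folder (the name `gan_weights` and the helper `get_weights_dir` are hypothetical) and that Drive shows up under `/content/drive/MyDrive`, the usual Colab layout; check the file browser if yours differs. Outside Colab it falls back to a local folder, so the same code also runs on your own machine:

```python
import os


def get_weights_dir(subdir="gan_weights"):
    """Return a checkpoint folder that outlives the Colab runtime.

    In Colab, mounts Google Drive at /content/drive and uses a folder under
    /content/drive/MyDrive (assumed default layout). Elsewhere, falls back to
    a folder in the current working directory. `subdir` is a hypothetical name.
    """
    try:
        from google.colab import drive  # only importable inside Colab
    except ImportError:
        base = os.getcwd()  # not in Colab: keep the files locally
    else:
        drive.mount("/content/drive")  # asks for authorization in the notebook
        base = "/content/drive/MyDrive"
    path = os.path.join(base, subdir)
    os.makedirs(path, exist_ok=True)
    # Trailing separator, because the training code builds file names as
    # weights_save_dir + '%d_gen_weights.h5'
    return path + os.sep
```

You would then call `weights_save_dir = get_weights_dir()` once per session and pass that path to the training function, so every checkpoint lands on Drive.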

If your weights are in your Google Drive, you can continue training by loading them into your Keras model (built with the same architecture as the one that saved them), simply with:

model.load_weights('path_to_weights')
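To resume, you also need to know which epoch the last saved weights belong to. A small helper, assuming the file names used in the training loop below (`<epoch>_gen_weights.h5` and `<epoch>_dis_weights.h5`, all in one folder); `latest_checkpoint` is a hypothetical name, not a Keras function:

```python
import os
import re


def latest_checkpoint(weights_dir):
    """Find the newest epoch for which both GAN weight files exist.

    Assumes the naming '<epoch>_gen_weights.h5' / '<epoch>_dis_weights.h5'.
    Returns (epoch, gen_path, dis_path), or (0, None, None) if nothing is saved.
    """
    pattern = re.compile(r"^(\d+)_(gen|dis)_weights\.h5$")
    found = {}  # epoch -> set of {"gen", "dis"} seen for that epoch
    for name in os.listdir(weights_dir):
        match = pattern.match(name)
        if match:
            epoch, kind = int(match.group(1)), match.group(2)
            found.setdefault(epoch, set()).add(kind)
    # Only epochs where both networks were saved are safe to resume from
    complete = [e for e, kinds in found.items() if kinds == {"gen", "dis"}]
    if not complete:
        return 0, None, None
    epoch = max(complete)
    return (
        epoch,
        os.path.join(weights_dir, "%d_gen_weights.h5" % epoch),
        os.path.join(weights_dir, "%d_dis_weights.h5" % epoch),
    )
```

One possible resume order, under those assumptions: build `generator` and `discriminator` exactly as in `train`, compile them, call `generator.load_weights(gen_path)` and `discriminator.load_weights(dis_path)`, then build the combined model with `get_gan_network`. Because `gan` wraps the same generator and discriminator objects, it uses the loaded weights without a separate load. Then start the epoch loop at `epoch + 1` instead of 1. Keep in mind that `save_weights` in `.h5` format stores only the layer weights, not the optimizer state, so the optimizer starts fresh.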

Thank you for your answer, @Ioannis Nasios. Yes, my weights are in Google Drive. I'm training a GAN and I'm trying to figure out how to load these weights and continue the training. I saved the discriminator and generator weights, and also gan_loss and discriminator_loss. So, do I have to compile the generator and discriminator networks, load the weights, and then compile the GAN network with their losses? It may be a stupid question, but it is my first time training a GAN. Here is my code:

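import numpy as np
from tqdm import tqdm
from keras.layers import Input
from keras.models import Model

# Generator, Discriminator, VGG_LOSS, Utils, Utils_model, image_shape and
# downscale_factor are defined elsewhere in the project (not shown here).

# Combined network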
def get_gan_network(discriminator, shape, generator, optimizer, loss):
    discriminator.trainable = False
    gan_input = Input(shape=shape)
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=[x,gan_output])
    gan.compile(loss=[loss, "binary_crossentropy"],
                loss_weights=[1., 1e-3],
                optimizer=optimizer)

    return gan

def train(x_train_lr, x_train_hr, x_test_lr, x_test_hr, epochs, batch_size, output_dir, model_save_dir, weights_save_dir):


    loss = VGG_LOSS(image_shape)  

    batch_count = x_train_hr.shape[0] // batch_size
    #### IF THE IMAGES ARE NOT SQUARE, THIS SHOULD CHANGE
    shape_lr = (image_shape[0]//downscale_factor, image_shape[1]//downscale_factor, image_shape[2])
    shape_hr = x_train_hr[0].shape
    ####
    generator = Generator(shape_lr, shape_hr).generator()
    discriminator = Discriminator(image_shape).discriminator()

    optimizer = Utils_model.get_optimizer()
    generator.compile(loss=loss.vgg_loss, optimizer=optimizer)
    discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)

    gan = get_gan_network(discriminator, shape_lr, generator, optimizer, loss.vgg_loss)

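    # Mode 'w+' empties an existing losses.txt, which also wipes the log when
    # training is restarted; use 'a' here when resuming.
    with open(model_save_dir + 'losses.txt', 'w+'):
        pass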

    for e in range(1, epochs+1):
        print ('-'*15, 'Epoch %d' % e, '-'*15)
        for _ in tqdm(range(batch_count)):

            rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)

            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]

            generated_images_sr = generator.predict(image_batch_lr)

            real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
            fake_data_Y = np.random.random_sample(batch_size)*0.2

            discriminator.trainable = True

            d_loss_real = discriminator.train_on_batch(image_batch_hr, real_data_Y)
            d_loss_fake = discriminator.train_on_batch(generated_images_sr, fake_data_Y)
            discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

            rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
            image_batch_hr = x_train_hr[rand_nums]
            image_batch_lr = x_train_lr[rand_nums]

            gan_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
            discriminator.trainable = False
            gan_loss = gan.train_on_batch(image_batch_lr, [image_batch_hr,gan_Y])


        print("discriminator_loss : %f" % discriminator_loss)
        print("gan_loss :", gan_loss)
        gan_loss = str(gan_loss)

        with open(model_save_dir + 'losses.txt', 'a') as loss_file:
            loss_file.write('epoch%d : gan_loss = %s ; discriminator_loss = %f\n' % (e, gan_loss, discriminator_loss))

        if e == 1 or e % 5 == 0:
            Utils.plot_generated_images(output_dir, e, generator, x_test_hr, x_test_lr)
            generator.save_weights(weights_save_dir + '%d_gen_weights.h5' % e)
            discriminator.save_weights(weights_save_dir + '%d_dis_weights.h5' % e)

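        # e never equals epochs+1 inside range(1, epochs+1); compare with epochs
        # so the full models are also saved after the last epoch.
        if e % 500 == 0 or e == epochs: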
            generator.save(model_save_dir + 'gen_model%d.h5' % e)
            discriminator.save(model_save_dir + 'dis_model%d.h5' % e)
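Since `losses.txt` records every epoch, it can be read back when resuming, for example to check how far the run got or to carry the loss history over. A minimal sketch, assuming the exact line format written by `loss_file.write` above; `read_loss_log` is a hypothetical helper:

```python
import re

# Matches lines like:
# 'epoch12 : gan_loss = [...] ; discriminator_loss = 0.693147'
LINE_RE = re.compile(
    r"^epoch(\d+) : gan_loss = (.*) ; discriminator_loss = (\S+)\s*$"
)


def read_loss_log(path):
    """Return [(epoch, gan_loss_text, discriminator_loss), ...] from losses.txt.

    gan_loss is kept as the raw text produced by str(gan_loss); the
    discriminator loss was written with %f, so float() can read it back.
    Lines that don't match the format are skipped.
    """
    records = []
    with open(path) as f:
        for line in f:
            match = LINE_RE.match(line)
            if match:
                records.append(
                    (int(match.group(1)), match.group(2), float(match.group(3)))
                )
    return records
```

Two things to watch: `train` opens `losses.txt` with mode `'w+'`, which empties it, so read the log (or switch that call to `'a'`) before calling `train` again. Also, the log has an entry for every epoch while weights are only saved at epoch 1 and every 5 epochs, so the last logged epoch can be later than the last usable checkpoint; resume from the epoch of the last saved weight files.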
