
Does tensorflow re-initialize weights when training in a for loop?

I'm training a model within a for loop, because... I can. I know there are alternatives, such as the `tf.data.Dataset` API with generators to stream data from disk, but my question is about the specific case of a loop.

Does TF re-initialize the model's weights at the beginning of each loop iteration? Or does the initialization only occur the first time the model is instantiated?

EDIT:

```python
import numpy as np
import pandas as pd
import tensorflow as tf

# LIST, column_order, SEQ_LEN, model_type, BATCH_SIZE, EPOCHS,
# clean_and_prepare, plot_train_history and `model` are defined elsewhere.
# `model` is created (and compiled) once, before this loop; it is not rebuilt inside it.
for msn in LIST:

    data = pd.read_parquet(
        "03 - Data",
        engine="pyarrow")
    data = data[column_order]
    data.rename(columns={"Flight_Id_Int": "Flight_Id"}, inplace=True)

    # DATA PREPARATION AND FORMATTING
    data_clean = clean_and_prepare(data, SEQ_LEN, input_type=model_type, smooth=True)

    # To keep the chronological order of flights, we don't shuffle randomly
    train_idx = np.arange(0, int(len(data_clean) * 0.9))
    test_idx = np.arange(int(len(data_clean) * 0.9), len(data_clean))

    # Autoencoder: inputs and targets are the same arrays
    train_df = tf.data.Dataset.from_tensor_slices(
        (data_clean[train_idx], data_clean[train_idx])
    ).batch(BATCH_SIZE)

    test_df = tf.data.Dataset.from_tensor_slices(
        (data_clean[test_idx], data_clean[test_idx])
    ).batch(BATCH_SIZE)

    # MODEL TRAINING
    history = model.fit(
        train_df,
        epochs=EPOCHS,
        validation_data=test_df,
        callbacks=[tf.keras.callbacks.EarlyStopping(
            monitor="val_loss",
            patience=15,
            mode="min",
            restore_best_weights=True)])

    plot_train_history(history, "Autoencoder {0} - MSN: {1}".format(model_type, msn))
```
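Because `model` is created once outside the loop, every iteration keeps training the same weights. If each `msn` should instead get an independently trained model, one option is to build the model inside the loop. The sketch below is only an illustration: `build_model()` is a hypothetical factory, the layer sizes are arbitrary, and random arrays stand in for `LIST` and `data_clean`. `tf.keras.models.clone_model(model)`, which copies the architecture with newly initialized weights (the clone has to be compiled again), is an alternative to a factory.

```python
import numpy as np
import tensorflow as tf


def build_model(n_features: int) -> tf.keras.Model:
    """Hypothetical factory: each call creates new layers with freshly initialized weights."""
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(n_features,)),
        tf.keras.layers.Dense(8, activation="relu"),  # arbitrary hidden size, illustration only
        tf.keras.layers.Dense(n_features),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


rng = np.random.default_rng(0)
initial_weights = {}
histories = {}

for msn in ["A", "B"]:  # stand-in for LIST
    x = rng.normal(size=(128, 4)).astype("float32")  # stand-in for data_clean
    model = build_model(n_features=4)  # new model: nothing carries over from the previous msn
    initial_weights[msn] = [w.copy() for w in model.get_weights()]
    histories[msn] = model.fit(x, x, epochs=2, batch_size=32, verbose=0)
```

The trade-off is that knowledge learned on one `msn` is no longer reused for the next; keeping a single model outside the loop is the right choice when that carry-over is intended.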

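Weights are initialized when the layers are built: when the model is defined, if the input shape is known, or otherwise the first time the model sees data. Either way this happens before any training step (before `fit`). Keras does not re-initialize the weights afterward, even if you call `fit` multiple times; each call continues from the current weights, so in your loop every iteration starts from the weights left by the previous one.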
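One way to check this yourself, sketched with a hypothetical toy linear-regression model and random data: a small callback (`CaptureStartWeights`, defined here, not part of Keras) copies the weights at the moment each `fit` call begins, so they can be compared with the weights left by the previous call.

```python
import numpy as np
import tensorflow as tf

tf.keras.utils.set_random_seed(0)  # reproducible initialization and shuffling


class CaptureStartWeights(tf.keras.callbacks.Callback):
    """Helper (not part of Keras): copies the weights at the moment fit() starts training."""

    def on_train_begin(self, logs=None):
        self.start_weights = [w.copy() for w in self.model.get_weights()]


def same_weights(a, b):
    return all(np.array_equal(u, v) for u, v in zip(a, b))


# Hypothetical toy regression data, for illustration only.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4)).astype("float32")
y = (x @ np.array([1.0, -2.0, 0.5, 3.0], dtype="float32"))[:, None]

# The input shape is known, so the Dense weights are created and randomly
# initialized here, before fit() is ever called.
model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")
initial = [w.copy() for w in model.get_weights()]

capture = CaptureStartWeights()

model.fit(x, y, epochs=5, verbose=0, callbacks=[capture])
start_1 = capture.start_weights
end_1 = [w.copy() for w in model.get_weights()]
loss_1 = model.evaluate(x, y, verbose=0)

model.fit(x, y, epochs=5, verbose=0, callbacks=[capture])
start_2 = capture.start_weights
loss_2 = model.evaluate(x, y, verbose=0)

# First fit() starts from the initial weights; second fit() starts exactly
# where the first one stopped, so nothing was re-initialized in between.
starts_from_initial = same_weights(initial, start_1)
resumes_from_previous = same_weights(end_1, start_2)
print(starts_from_initial, resumes_from_previous, loss_2 < loss_1)
```

If `fit` re-initialized the model, `start_2` would differ from `end_1` and the loss after the second call would not keep improving on the first.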

To show this is the case, I plotted the decision boundary at regular intervals during training (by calling `fit` and then `predict`):

[Figure: decision boundary at successive stages of training]
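A rough way to reproduce this kind of experiment, as a sketch on hypothetical 2-D toy data (points labeled by whether they fall inside the unit circle) rather than the original plotting code: the model is trained in rounds of 10 epochs, and after each round `predict` is evaluated on a grid, giving one decision-boundary snapshot per round. Because the weights carry over between `fit` calls, the loss recorded at the end of each round should generally keep decreasing instead of starting over.

```python
import numpy as np
import tensorflow as tf

tf.keras.utils.set_random_seed(0)

# Hypothetical two-class toy data: label 1 inside the unit circle, 0 outside.
rng = np.random.default_rng(0)
x = rng.uniform(-2.0, 2.0, size=(512, 2)).astype("float32")
y = (np.sum(x ** 2, axis=1) < 1.0).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(optimizer="adam", loss="binary_crossentropy")

# Grid of points on which the decision boundary is evaluated.
xs = np.linspace(-2.0, 2.0, 50, dtype="float32")
grid = np.stack(np.meshgrid(xs, xs), axis=-1).reshape(-1, 2)

snapshots = []  # one probability map per training round
losses = []     # training loss at the end of each round
for _ in range(5):
    # Each fit() call resumes from the weights left by the previous call.
    hist = model.fit(x, y, epochs=10, batch_size=32, verbose=0)
    losses.append(hist.history["loss"][-1])
    snapshots.append(model.predict(grid, verbose=0).reshape(xs.size, xs.size))
```

Each snapshot is a 50 x 50 map of predicted probabilities; thresholding it at 0.5 (for instance with matplotlib's `contour`) draws the boundary for that round.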
