Accessing accuracy/loss from .log files after training a model (using .fit_generator) with multiple pauses

I'm using Amazon SageMaker Studio Lab to train a model on a certain dataset.
The code is as follows (it saves the History object in the history variable):

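import tensorflow as tf
from tensorflow.keras.callbacks import CSVLogger, ModelCheckpoint
from tensorflow.keras.optimizers import SGD

# Every run after the runtime ends, resume from the last saved model
model = tf.keras.models.load_model('best_model.hdf5')
model.compile(optimizer=SGD(learning_rate=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])
checkpointer = ModelCheckpoint(filepath='best_model.hdf5', verbose=1, save_best_only=True)
# Writes one row per epoch; overwrites history.log on each run unless append=True
csv_logger = CSVLogger('history.log')

# train_generator, validation_generator, nb_train_samples, nb_validation_samples
# and batch_size are defined elsewhere; fit_generator is deprecated in TF 2.x,
# where fit accepts generators directly
history = model.fit(train_generator,
                    steps_per_epoch=nb_train_samples // batch_size,
                    validation_data=validation_generator,
                    validation_steps=nb_validation_samples // batch_size,
                    epochs=30,
                    verbose=1,
                    callbacks=[csv_logger, checkpointer])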
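With `CSVLogger('history.log')` configured as above, the log file is truncated at the start of every `fit` call because `append` defaults to `False`. Each session's log therefore has to be copied out before the next run and merged afterwards. Below is a minimal sketch of the merge step with pandas. It assumes each session's copy was saved under its own name (`history_1.log`, `history_2.log`, ... are hypothetical) and still has its header row. `concat_logs` is a hypothetical helper:

```python
import pandas as pd


def concat_logs(paths):
    """Merge per-session CSVLogger files, in training order, into one DataFrame.

    Each resumed run starts its 'epoch' column again at 0 (unless
    initial_epoch was passed to fit), so epochs are renumbered 0..N-1 here.
    """
    frames = [pd.read_csv(path) for path in paths]
    merged = pd.concat(frames, ignore_index=True)
    merged["epoch"] = range(len(merged))
    return merged
```

For example, `concat_logs(['history_1.log', 'history_2.log']).to_csv('history.log', index=False)` writes a single combined log. Alternatively, creating the logger as `CSVLogger('history.log', append=True)` and passing `initial_epoch` to `fit` on each resumed run should keep one continuous file to begin with.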

I had to pause several times because the runtime ended, and at each pause I saved the .log file. Now, after appending those .log files together, I'm trying to plot them with the standard accuracy and loss plotting functions:

import matplotlib.pyplot as plt

# Both functions expect a Keras History object, whose .history attribute
# is a dict mapping each metric name to a list of per-epoch values
def plot_accuracy(history, title):
    plt.title(title)
    plt.plot(history.history['accuracy'])
    plt.plot(history.history['val_accuracy'])
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train_accuracy', 'validation_accuracy'], loc='best')
    plt.show()

def plot_loss(history, title):
    plt.title(title)
    plt.plot(history.history['loss'])
    plt.plot(history.history['val_loss'])
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train_loss', 'validation_loss'], loc='best')
    plt.show()

But the problem is that I can't manage to get a working History object back.
Things I tried:
I read about reloading the model and reading its history attribute, but that failed with "'NoneType' object has no attribute 'history'":

model = tf.keras.models.load_model('best_model.hdf5')
history = model.history

Another attempt was loading the file with pandas, which raised "'DataFrame' object has no attribute 'history'":

history = pd.read_csv('history.log', sep=',', engine='python')

And this attempt just created a new CSVLogger callback, which raised "'CSVLogger' object has no attribute 'history'":

history = CSVLogger('history.log')

I'd appreciate any help on recovering the History object so I can plot those results, if that's even possible.
Thanks.
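The `'DataFrame' object has no attribute 'history'` error comes from the plotting functions, not from pandas. Those functions read `history.history`, which on a Keras History object is a plain dict mapping each metric name to a list of per-epoch values. A DataFrame's columns can be repackaged into that shape, so the original functions work unchanged. Below is a minimal sketch, assuming the file was written by CSVLogger (an `epoch` column plus one column per metric). `history_from_log` is a hypothetical helper, and `types.SimpleNamespace` only imitates the one attribute the plots use; it is not a real History object:

```python
from types import SimpleNamespace

import pandas as pd


def history_from_log(path):
    """Return a minimal stand-in for a Keras History object from a CSVLogger file.

    Only the .history attribute is rebuilt: a dict mapping each metric name
    to its list of per-epoch values, which is all plot_accuracy and
    plot_loss read. The 'epoch' column is dropped because Keras'
    History.history does not contain it.
    """
    df = pd.read_csv(path)
    metrics = df.drop(columns=["epoch"], errors="ignore")
    return SimpleNamespace(history=metrics.to_dict("list"))
```

With it, `plot_accuracy(history_from_log('history.log'), 'plot title...')` should plot the merged log with the question's original functions.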

Instead of recreating the History object, I read the .log file with the pandas read_csv method, selected the wanted columns into DataFrames, and plotted those. Code below:

import matplotlib.pyplot as plt
import pandas as pd

# Define the functions before calling them. Each receives a DataFrame;
# plt.plot draws one line per column against the row index (epoch).
def plot_accuracy(history, title):
    plt.title(title)
    plt.plot(history)
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.legend(['train_accuracy', 'validation_accuracy'], loc='best')
    plt.show()

def plot_loss(history, title):
    plt.title(title)
    plt.plot(history)
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train_loss', 'validation_loss'], loc='best')
    plt.show()

history = pd.read_csv('history.log')
# Column order matches the legend labels above
history_acc = history[["accuracy", "val_accuracy"]]
history_loss = history[["loss", "val_loss"]]
plot_accuracy(history_acc, 'plot title...')
plot_loss(history_loss, 'plot title...')
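The plots above use the row index as the x axis. That matches the epoch count only if the merged log has exactly one row per epoch, in order. The variant below plots against the log's own `epoch` column instead and shows accuracy and loss side by side, using pandas' built-in `DataFrame.plot`. `plot_history` is a hypothetical name, and the column names assume the CSVLogger output of a model compiled with `metrics=['accuracy']`:

```python
import matplotlib.pyplot as plt


def plot_history(history, title):
    """Plot accuracy and loss from a CSVLogger DataFrame side by side.

    Uses the log's own 'epoch' column as the x axis instead of the row index.
    Returns the Figure so the caller can show or save it.
    """
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(10, 4))
    history.plot(x="epoch", y=["accuracy", "val_accuracy"], ax=ax_acc,
                 title=f"{title}: accuracy")
    history.plot(x="epoch", y=["loss", "val_loss"], ax=ax_loss,
                 title=f"{title}: loss")
    ax_acc.set_ylabel("accuracy")
    ax_loss.set_ylabel("loss")
    fig.tight_layout()
    return fig
```

Typical use would be `plot_history(pd.read_csv('history.log'), 'plot title...')` followed by `plt.show()`. Returning the figure also allows `fig.savefig(...)` in a headless notebook session.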

Hope this helps anyone who runs into the same issue in the future.
