
Is there a way to stop training in the middle of an epoch with tensorflow?

Just wondering if there is a way to save the highest accuracy and lowest loss reached in the middle of an epoch and use that as the score to beat in the following epochs. Normally my model maxes out at 43.56% accuracy, but I've seen it go above 46% in the middle of an epoch. Is there a way I can stop the epoch at that point and use that value as the score to beat going forward?

Here's the code I'm running right now:

import pandas as pd
import numpy as np
import pickle
import random
from skopt import BayesSearchCV
from sklearn.neural_network import MLPRegressor
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, LSTM, Bidirectional, SimpleRNN, GRU
from tensorflow.keras.wrappers.scikit_learn import KerasRegressor
from tensorflow.keras import layers
import tensorflow_docs as tfdocs
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
import tensorflow_docs.modeling
from tensorflow import keras
import warnings
warnings.filterwarnings("ignore")
warnings.filterwarnings('ignore', category=DeprecationWarning)

# avenues, clean_dataset, norm, build_model, param_grid, es_acc and es_loss
# are defined elsewhere in the script (not shown in the question).
train_df = avenues[["LFA's", "Spend"]].sample(frac=0.8, random_state=0)
test_df = avenues[["LFA's", "Spend"]].drop(train_df.index)
train_df = clean_dataset(train_df)
test_df = clean_dataset(test_df)
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

# Summary statistics of the feature columns only: pop() removes the label
# column in place (assigning its return value would keep the label's stats instead).
train_stats = train_df.describe()
train_stats.pop("LFA's")
train_stats = train_stats.transpose()

train_labels = train_df.pop("LFA's").values
test_labels = test_df.pop("LFA's").values

# Shape (samples, timesteps=1, features=1) for the recurrent layers.
normed_train_data = np.array(norm(train_df)).reshape((train_df.shape[0], 1, 1))
normed_test_data = np.array(norm(test_df)).reshape((test_df.shape[0], 1, 1))

model = KerasRegressor(build_fn=build_model, epochs=25,
                       batch_size=1, verbose=0)
gs = BayesSearchCV(model, param_grid, cv=3, n_iter=25, n_jobs=1,
                   optimizer_kwargs={'base_estimator': 'RF'},
                   fit_params={"callbacks": [es_acc, es_loss, tfdocs.modeling.EpochDots()]})
try:
    gs.fit(normed_train_data, train_labels)
except Exception as e:
    print(e)
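Because the script already passes callbacks to every fit through fit_params, one way to capture the mid-epoch best without rewriting the training loop might be a custom Keras callback that checks the metric after every batch. This is a sketch under assumptions, not something from the question or answer: BestBatchCheckpoint, monitor, mode and patience are hypothetical names. It relies on the standard Callback hooks (on_train_batch_end), get_weights/set_weights and model.stop_training. It also assumes the batch-level logs hold the running average for the epoch so far, which is what recent TF 2.x Keras reports and what the progress bar shows mid-epoch.

```python
import numpy as np
import tensorflow as tf


class BestBatchCheckpoint(tf.keras.callbacks.Callback):
    """Hypothetical callback: remember the weights from the batch where the
    monitored metric was best, optionally request a stop after `patience`
    batches without improvement, and restore those weights when fit ends."""

    def __init__(self, monitor="loss", mode="min", patience=None):
        super().__init__()
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.monitor = monitor
        self.mode = mode
        self.patience = patience  # None means never request an early stop
        self._reset()

    def _reset(self):
        self.best = None           # best metric value seen so far
        self.best_weights = None   # model weights captured at that batch
        self.best_location = None  # (epoch, batch) where it was seen
        self.wait = 0              # batches since the last improvement
        self.epoch = 0

    def _improved(self, current):
        if self.best is None:
            return True
        return current < self.best if self.mode == "min" else current > self.best

    def on_train_begin(self, logs=None):
        # The same instance may be reused for many fits (e.g. every
        # cross-validation fold), so each fit starts from a clean state.
        self._reset()

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch = epoch

    def on_train_batch_end(self, batch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            return  # metric not reported under this name
        current = float(current)
        if self._improved(current):
            self.best = current
            self.best_weights = self.model.get_weights()
            self.best_location = (self.epoch, batch)
            self.wait = 0
        else:
            self.wait += 1
            if self.patience is not None and self.wait >= self.patience:
                # Request a stop; depending on the Keras version this ends
                # the current epoch right away or only once it finishes.
                self.model.stop_training = True

    def on_train_end(self, logs=None):
        # Leave the model holding the weights from its best batch.
        if self.best_weights is not None:
            self.model.set_weights(self.best_weights)
```

For example, adding BestBatchCheckpoint(monitor="loss", mode="min") to the callbacks list in fit_params would make each cross-validation fit finish with the weights from its best batch. With batch_size=1, as in the question, copying the weights on every improvement can slow training noticeably.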

Try using train_on_batch instead of fit. That way you control the loop inside each epoch and can stop after whichever batch you like (although from a machine learning point of view I doubt it's a good idea, and you'll likely end up with a less generalized model).
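A minimal sketch of the train_on_batch approach, assuming a compiled Keras model and in-memory NumPy arrays; train_with_batch_control, patience and seed are hypothetical names used only for illustration. Each step resets the metrics (so the returned loss is for that batch alone), trains on one batch, and keeps a copy of the weights whenever that batch's loss is the lowest so far. If patience is set, training stops mid-epoch after that many batches without improvement, and the best weights are restored either way.

```python
import numpy as np
import tensorflow as tf


def train_with_batch_control(model, x, y, epochs=25, batch_size=32,
                             patience=None, seed=0):
    """Hypothetical helper: manual loop over train_on_batch that tracks the
    best per-batch loss, can stop mid-epoch, and restores the best weights.
    Returns (best_loss, (epoch, step) where it occurred)."""
    rng = np.random.default_rng(seed)
    n = len(x)
    best_loss, best_weights, best_at = np.inf, None, None
    wait, stop = 0, False
    for epoch in range(epochs):
        order = rng.permutation(n)  # reshuffle every epoch, like fit(shuffle=True)
        for step, start in enumerate(range(0, n, batch_size)):
            idx = order[start:start + batch_size]
            model.reset_metrics()  # make the reported loss cover this batch only
            logs = model.train_on_batch(x[idx], y[idx], return_dict=True)
            loss = float(logs["loss"])
            if loss < best_loss:
                best_loss, best_at = loss, (epoch, step)
                best_weights = model.get_weights()
                wait = 0
            else:
                wait += 1
                if patience is not None and wait >= patience:
                    stop = True  # leave the epoch early
                    break
        if stop:
            break
    if best_weights is not None:
        model.set_weights(best_weights)
    return best_loss, best_at
```

For example, train_with_batch_control(model, x, y, epochs=25, batch_size=32, patience=200) returns the best loss and the (epoch, step) where it occurred. Selecting on a single training batch's loss is noisy and favors batches that happen to be easy, which is one reason to be wary of this approach; scoring a held-out set with model.evaluate every few batches would be a safer criterion.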
