Hyper-parameter Tuning and Over-fitting with a Feed-Forward Neural Network - Mini-Batch Epochs and Cross-Validation

I am looking at implementing hyper-parameter tuning for a feed-forward neural network (FNN) built in PyTorch. My original FNN, a model named net, is trained with a mini-batch learning approach over a number of epochs:

# Imports needed for this snippet: torch/nn for the model and loss, sklearn's shuffle to re-order the data each epoch
import torch
import torch.nn as nn
from sklearn.utils import shuffle

# Parameters
batch_size = 50        # larger batch sizes led to over-fitting
num_epochs = 1000
learning_rate = 0.01   # AKA step size - the amount the weights are updated during training
batch_no = len(x_train) // batch_size

criterion = nn.CrossEntropyLoss()  # classification loss; expects raw class scores (logits) and integer labels
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    if epoch % 20 == 0:
        print('Epoch {}'.format(epoch + 1))
    x_train, y_train = shuffle(x_train, y_train)
    # Mini-batch learning: batch size < n (batch gradient descent) but > 1 (stochastic gradient descent)
    for i in range(batch_no):
        start = i * batch_size
        end = start + batch_size
        x_var = torch.FloatTensor(x_train[start:end])  # Variable is deprecated; plain tensors work with autograd
        y_var = torch.LongTensor(y_train[start:end])
        # Forward + backward + optimize
        optimizer.zero_grad()
        ypred_var = net(x_var)
        loss = criterion(ypred_var, y_var)
        loss.backward()
        optimizer.step()

Finally, I test the model on a separate, held-out test set.
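
For completeness, the test step is roughly the sketch below (not necessarily my exact code), assuming x_test and y_test are NumPy arrays shaped like x_train and y_train:

net.eval()                      # disable training-only behaviour such as dropout
with torch.no_grad():           # no gradients needed at test time
    logits = net(torch.FloatTensor(x_test))
    preds = logits.argmax(dim=1)
    accuracy = (preds == torch.LongTensor(y_test)).float().mean().item()
print('Test accuracy: {:.3f}'.format(accuracy))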

I came across an approach that uses randomised search to tune the hyper-parameters while also performing k-fold cross-validation (scikit-learn's RandomizedSearchCV).

My question is two-fold (no pun intended!), and the first part is theoretical: is k-fold cross-validation necessary, or does it add any benefit, for a mini-batch feed-forward neural network? From what I can see, the mini-batch approach should do roughly the same job in preventing over-fitting.

I also found a good answer here, but I'm not sure it addresses a mini-batch approach specifically.

Secondly, if k-fold is not necessary, is there another hyper-parameter tuning function for PyTorch that would save me from writing one manually?
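
To frame the second part, below is a minimal sketch (not my working code) of what the RandomizedSearchCV route could look like, assuming the third-party skorch wrapper, which exposes a PyTorch module as a scikit-learn estimator; the Net class, its dimensions, and the parameter ranges are placeholder assumptions rather than my actual model:

# Sketch: RandomizedSearchCV + k-fold CV via skorch (third-party: pip install skorch).
# Net, n_features, n_classes and the parameter ranges are placeholders for illustration.
import torch
import torch.nn as nn
from skorch import NeuralNetClassifier
from sklearn.model_selection import RandomizedSearchCV

n_features, n_classes = 20, 3  # placeholder dimensions - use your own

class Net(nn.Module):
    def __init__(self, hidden_units=64):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(n_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, n_classes),
        )

    def forward(self, x):
        return self.layers(x)

# skorch exposes lr, batch_size, max_epochs, etc. as scikit-learn hyper-parameters
model = NeuralNetClassifier(
    Net,
    criterion=nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    max_epochs=100,
    verbose=0,
)

param_distributions = {
    'lr': [1e-3, 1e-2, 1e-1],
    'batch_size': [25, 50, 100],
    'module__hidden_units': [32, 64, 128],  # the module__ prefix routes the value to Net.__init__
}

# cv=5 runs 5-fold cross-validation inside the randomised search
search = RandomizedSearchCV(model, param_distributions, n_iter=10, cv=5, scoring='accuracy')
# search.fit(x_train.astype('float32'), y_train.astype('int64'))
# print(search.best_params_, search.best_score_)

Because the search takes cv=5, the k-fold cross-validation happens inside RandomizedSearchCV itself, so tuning and validation would be handled in one call.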

  • k-fold cross validation is generally useful when you have a very small dataset. Thus, if you are training on a dataset like CIFAR10 (which is large, 60000 images), then you don't require k-fold cross validation.
  • The idea of k-fold cross validation is to see how model performance (generalization) varies as different subsets of the data are used for training and testing. This becomes important when you have very little data. However, for large datasets, the metric results on a held-out test set are enough to assess the generalization of the model.
  • Thus, whether you require k-fold cross validation depends on the size of your dataset; it does not depend on what model you use (a minimal k-fold sketch is given after this answer).
  • If you look at this chapter of the Deep Learning book (this was first referenced in this link):

Small batches can offer a regularizing effect (Wilson and Martinez, 2003), perhaps due to the noise they add to the learning process. Generalization error is often best for a batch size of 1. Training with such a small batch size might require a small learning rate to maintain stability because of the high variance in the estimate of the gradient. The total runtime can be very high as a result of the need to make more steps, both because of the reduced learning rate and because it takes more steps to observe the entire training set.

  • So, yes, mini-batch training will have a regularizing effect (reducing overfitting) to some extent.
  • There is no built-in hyper-parameter tuning in PyTorch (at least at the time of writing this answer), but many developers have built tools for this purpose (for example). You can find more by searching for them; this question has answers that list a lot of such tools.
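
If your dataset does turn out to be small enough that k-fold cross-validation is worth it, a minimal sketch of wrapping the question's training loop in scikit-learn's KFold could look like the following; build_net, train_net and evaluate are hypothetical helpers standing in for the model construction, the mini-batch loop and the test step shown in the question:

# Sketch: manual k-fold cross-validation around a PyTorch training loop.
# build_net/train_net/evaluate are hypothetical helpers wrapping the code from the question.
import numpy as np
from sklearn.model_selection import KFold

def cross_validate(x, y, k=5):
    kf = KFold(n_splits=k, shuffle=True, random_state=0)
    scores = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(x)):
        net = build_net()                            # fresh, untrained model for every fold
        train_net(net, x[train_idx], y[train_idx])   # the mini-batch loop from the question
        scores.append(evaluate(net, x[val_idx], y[val_idx]))
        print('Fold {}: {:.3f}'.format(fold + 1, scores[-1]))
    return np.mean(scores), np.std(scores)           # the spread across folds indicates generalization stability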
