
DREAM - Neural Network not Converging - Loss down and up

I'm using this DREAM implementation for next-basket prediction: https://github.com/LaceyChen17/DREAM

It's pretty easy and straightforward, needing only a few tweaks to run (minor errors from constants), but my loss behaves erratically once I try to train it.

I'm trying to reproduce its output using the proposed dataset (from Instacart): https://www.instacart.com/datasets/grocery-shopping-2017

But the training doesn't seem to converge. I tried changing the learning rate from 0.1 to 0.001 and tweaking the clip value and the dropout, but nothing helps. My loss keeps going down and then back up again.

I'm trying to study this network. Once I get it to run successfully a first time, I want to build on it from there, but right now I can't figure out how to debug the problem.

Here is a sample of my config:

DREAM_CONFIG = {'basket_pool_type': 'max',  # 'avg'
                'rnn_layers': 3,  # 2, 3
                'rnn_type': 'LSTM',  # 'RNN_TANH', 'GRU', 'LSTM', 'RNN_RELU'
                'dropout': 0.5,
                # 'num_product': 49688 + 1,  # padding idx = 0
                'num_product': 49688 + 1 + 1,
                # 49688 products, padding idx = 0, none idx = 49689; none idx indicates no products
                'none_idx': 49689,
                'embedding_dim': 64,  # 128
                'cuda': False,  # True
                'clip': 20,  # 0.25
                'epochs': 100,
                'batch_size': 256,
                'learning_rate': 0.0001,  # 0.0001
                'log_interval': 1,  # number of batches between two log messages
                'checkpoint_dir': DREAM_MODEL_DIR + 'reorder-next-dream-{epoch:02d}-{loss:.4f}.model',
                }

Any insights?

There may be other things to try to improve training, but at a minimum I would recommend using the "ModelCheckpoint" class to save the weights with the best validation loss.

If you are using Keras, see the Keras Callbacks documentation for more information.

You'll want to set aside some of the training data as a validation set. The callback computes the loss on the validation set and saves the weights each time that loss improves. After training, you can load the best weights, which limits the effect of overfitting in later epochs. This at least keeps you from ending up with a model that got worse during training, and from there you can troubleshoot by making additional tweaks.
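
If you are not on Keras (the DREAM repository linked in the question appears to be PyTorch-based, judging by the `cuda` flag in the config, but that is an assumption), the same idea can be sketched by hand. The snippet below is only an illustration in plain Python: `split_train_val` and `BestCheckpoint` are hypothetical helper names, not part of Keras, PyTorch, or the DREAM code, and the loss values are made up to mimic a curve that goes down and then up.

```python
import copy
import random


def split_train_val(samples, val_fraction=0.1, seed=0):
    """Shuffle a copy of the samples and hold out a fraction for validation."""
    shuffled = list(samples)
    random.Random(seed).shuffle(shuffled)
    n_val = max(1, int(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


class BestCheckpoint:
    """Keeps a copy of the model state with the lowest validation loss so far."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_epoch = None
        self.best_state = None

    def update(self, epoch, val_loss, state):
        """Store `state` if `val_loss` improved; return True when it did."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.best_epoch = epoch
            # Deep copy so later training steps cannot mutate the saved weights.
            self.best_state = copy.deepcopy(state)
            return True
        return False


# Hypothetical usage: 100 dummy samples, 20% held out for validation.
train, val = split_train_val(range(100), val_fraction=0.2)

tracker = BestCheckpoint()
fake_val_losses = [0.90, 0.70, 0.65, 0.80, 0.95]  # made-up numbers: down, then up
for epoch, val_loss in enumerate(fake_val_losses):
    state = {"epoch": epoch}  # stand-in for the real model weights
    if tracker.update(epoch, val_loss, state):
        print(f"epoch {epoch}: validation loss improved to {val_loss}")
```

In a real PyTorch loop, `state` would typically be `model.state_dict()` and the validation loss would come from evaluating the model on the held-out set after each epoch. After training, you would restore `tracker.best_state` rather than the weights from the final epoch.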
