
Why does my training loss oscillate while training the final layer of AlexNet with pre-trained weights?

I am working on texture classification. Following previous work, I am modifying the final layer of AlexNet to have 20 classes and training only that layer for my multi-class classification problem. I am using TensorFlow-GPU on an NVIDIA GTX 1080, with Python 3.6 on Ubuntu 16.04. I build the model with the Estimator class and train with the gradient descent optimizer, using two dropout layers for regularization. My hyperparameters are therefore the learning rate, batch size, and weight decay. I have tried batch sizes of 50, 100, and 200, weight decays of 0.005 and 0.0005, and learning rates of 1e-3, 1e-4, and 1e-5. The training loss curves for all of these settings follow similar trends.
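For context, "training only that layer" with the TF 1.x graph API usually comes down to restricting minimize() to the final layer's variables via var_list. A minimal sketch; the scope name 'fc8' is an assumption, not taken from the code below:

  # Collect only the final layer's variables (the scope name 'fc8' is assumed)
  fc8_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='fc8')
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
  # With var_list set, gradients are computed and applied only for fc8;
  # the pre-trained layers stay frozen
  train_op = optimizer.minimize(loss, var_list=fc8_vars,
                                global_step=tf.train.get_global_step())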

My training loss curve does not decrease monotonically; instead it oscillates. I have provided a TensorBoard visualization for learning rate = 1e-5, weight decay = 0.0005, and batch size = 200.

Please help me understand what went wrong and how I can fix it.

[TensorBoard visualization of the training loss for the case specified above]

  # Create the Estimator
  classifier = tf.estimator.Estimator(model_fn=cnn_model)

  # Set up logging for predictions
  tensors_to_log = {"probabilities": "softmax_tensor"}
  logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=10)

  # Train the model
  train_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": train_data},
      y=train_labels,
      batch_size=batch_size,
      num_epochs=None,
      shuffle=True)
  classifier.train(input_fn=train_input_fn, steps=200000, hooks=[logging_hook])

  # Evaluate the model and print results
  eval_input_fn = tf.estimator.inputs.numpy_input_fn(
      x={"x": eval_data},
      y=eval_labels,
      num_epochs=1,
      shuffle=False)
  eval_results = classifier.evaluate(input_fn=eval_input_fn)
  print(eval_results)

  # Sections of the cnn_model

  # Output config
  predictions = {
      # Generate predictions (for PREDICT and EVAL mode)
      "classes": tf.argmax(input=logits, axis=1),
      # Add `softmax_tensor` to the graph; it is used for PREDICT and by the `logging_hook`
      "probabilities": tf.nn.softmax(logits, name="softmax_tensor")}
  if mode == tf.estimator.ModeKeys.PREDICT:
      return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

  # Calculate loss (for both TRAIN and EVAL modes)
  onehot_labels = tf.one_hot(indices=tf.cast(labels, tf.int32), depth=20)
  loss = tf.losses.softmax_cross_entropy(onehot_labels=onehot_labels, logits=logits)

  # Training config
  if mode == tf.estimator.ModeKeys.TRAIN:
      optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
      tf.summary.scalar('training_loss', loss)
      summary_hook = tf.train.SummarySaverHook(
          save_steps=10, output_dir='outputs', summary_op=tf.summary.merge_all())
      train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())
      return tf.estimator.EstimatorSpec(
          mode=mode, loss=loss, train_op=train_op, training_hooks=[summary_hook])

  # Evaluation metric: accuracy
  # tf.metrics.accuracy returns a (value, update_op) pair
  accuracy = tf.metrics.accuracy(labels=labels, predictions=predictions["classes"])
  eval_metric_ops = {"accuracy": accuracy}
  tf.summary.scalar('eval_loss', loss)
  # Summarize the streaming update op, not the raw tuple
  tf.summary.scalar('eval_accuracy', accuracy[1])
  evaluation_hook = tf.train.SummarySaverHook(
      save_steps=10, output_dir='outputseval', summary_op=tf.summary.merge_all())
  return tf.estimator.EstimatorSpec(
      mode=mode, loss=loss, eval_metric_ops=eval_metric_ops, evaluation_hooks=[evaluation_hook])
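The weight_decay hyperparameter mentioned above does not appear in these snippets. A common way to apply it, shown here only as an assumption about the setup rather than the actual code, is to add an L2 penalty over the trainable weights to the cross-entropy loss:

  # Assumed pattern (not shown in the question): L2 weight decay on the
  # non-bias trainable variables, added to the cross-entropy loss
  l2_penalty = tf.add_n([tf.nn.l2_loss(v) for v in tf.trainable_variables()
                         if 'bias' not in v.name])
  loss = loss + weight_decay * l2_penalty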

Are you selecting your mini-batches randomly? It looks like you have high variance across your mini-batches, which leads to high variance of the loss at different iterations. I assume the x-axis in your plot is iterations, not epochs, and that the training data served every ~160 iterations is harder to predict, which would explain the periodic dips in your loss curve. How does your validation loss behave?

Possible solutions/ideas:

  • Randomize your training-data selection more thoroughly (see the sketch below)
  • Check your training data for mislabeled examples
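
For the first point, one way to get a thorough reshuffle is to build the input pipeline with tf.data instead of numpy_input_fn. This is a minimal sketch assuming train_data and train_labels are the NumPy arrays from the question; a shuffle buffer as large as the dataset redraws a uniform permutation on every pass:

  def train_input_fn():
      # from_tensor_slices embeds the arrays in the graph;
      # fine for moderately sized datasets
      dataset = tf.data.Dataset.from_tensor_slices(({"x": train_data}, train_labels))
      # A buffer covering the whole dataset gives a full uniform shuffle,
      # redrawn on each epoch
      dataset = dataset.shuffle(buffer_size=len(train_data),
                                reshuffle_each_iteration=True)
      return dataset.repeat().batch(batch_size)

The training call, classifier.train(input_fn=train_input_fn, steps=200000, hooks=[logging_hook]), works unchanged with this input_fn.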
