Keras Callbacks Explained In Three Minutes

A gentle introduction to callbacks in Keras. Learn about EarlyStopping, ModelCheckpoint, and other callback functions with code examples.



By Andre Duong, UT Dallas

 

Building Deep Learning models without callbacks is like driving a car with no functioning brakes: you have little to no control over the process, and the result is very likely to be a disaster. In this article, you will learn how to monitor and improve your Deep Learning models using Keras callbacks like ModelCheckpoint and EarlyStopping.

 

What are callbacks?

 
From the Keras documentation:

A callback is a set of functions to be applied at given stages of the training procedure. You can use callbacks to get a view on internal states and statistics of the model during training.

You define and use a callback when you want to automate tasks that run after every training step or epoch and give you control over the training process. This includes stopping training when you reach a certain accuracy/loss score, saving your model as a checkpoint after each successful epoch, adjusting the learning rate over time, and more. Let’s dive into some callback functions!
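
To make the idea concrete, here is a toy sketch in plain Python of how a training loop can hand control to callbacks after every epoch. It is not Keras code: the PrintLoss and StopAtLoss classes, the on_epoch_end signature, and the fake loss values are all made up for illustration and are only loosely modeled on how Keras callbacks behave.

```python
class PrintLoss:
    """Toy callback: records and prints the loss after every epoch."""

    def __init__(self):
        self.seen = []

    def on_epoch_end(self, epoch, logs):
        self.seen.append(logs["loss"])
        print(f"epoch {epoch}: loss={logs['loss']:.2f}")
        return False  # this callback never asks training to stop


class StopAtLoss:
    """Toy callback: asks training to stop once the loss reaches a target."""

    def __init__(self, target):
        self.target = target

    def on_epoch_end(self, epoch, logs):
        return logs["loss"] <= self.target


def train(fake_losses, callbacks):
    """Pretend training loop: one fake loss value per epoch."""
    epochs_run = 0
    for epoch, loss in enumerate(fake_losses):
        epochs_run += 1
        logs = {"loss": loss}
        # Call every callback first, then check whether any requested a stop.
        stop_requests = [cb.on_epoch_end(epoch, logs) for cb in callbacks]
        if any(stop_requests):
            break
    return epochs_run


printer = PrintLoss()
epochs_run = train([0.9, 0.6, 0.4, 0.2, 0.1], [printer, StopAtLoss(0.4)])
```

In this sketch the loop stops as soon as one callback reports that its condition is met, which is the same kind of control the Keras callbacks described in the rest of this article give you over a real training run.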

 

EarlyStopping

 
Overfitting is a nightmare for Machine Learning practitioners. One way to avoid overfitting is to terminate the training process early. The EarlyStopping callback has several arguments that you can set to control when training should stop. Here are some relevant arguments:

  • monitor: the value being monitored, e.g. val_loss
  • min_delta: the minimum change in the monitored value that counts as an improvement. For example, min_delta=1 means that an absolute change of less than 1 is treated as no improvement
  • patience: the number of epochs with no improvement after which training will be stopped
  • restore_best_weights: set this argument to True if you want to restore the weights from the epoch with the best monitored value once training stops

The code example below defines an EarlyStopping callback that tracks the val_loss value, stops the training if val_loss has not improved for 3 consecutive epochs, and restores the best weights once the training stops:

from keras.callbacks import EarlyStopping

earlystop = EarlyStopping(monitor='val_loss',
                          min_delta=0,
                          patience=3,
                          verbose=1,
                          restore_best_weights=True)
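
To see how patience and min_delta interact, the following plain-Python sketch replays the stopping rule on a list of validation losses. It is a simplified re-implementation written for this article, not the Keras source: the function name simulate_early_stopping and the loss values are made up, and it only covers the case where a lower monitored value is better.

```python
def simulate_early_stopping(val_losses, patience=3, min_delta=0.0):
    """Return (stopped_epoch, best_epoch) for a list of per-epoch val_loss values.

    Epochs are indexed from 0. stopped_epoch is None if the rule never triggers.
    """
    best = float("inf")
    best_epoch = None
    wait = 0  # epochs since the last improvement
    for epoch, loss in enumerate(val_losses):
        # An epoch only counts as an improvement if it beats the best
        # value seen so far by more than min_delta.
        if loss < best - min_delta:
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


val_losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.73, 0.5]

# With min_delta=0, epochs 3, 4 and 5 bring no improvement over epoch 2.
stopped_epoch, best_epoch = simulate_early_stopping(val_losses, patience=3)

# With min_delta=0.15, the drop from 0.8 to 0.7 is too small to count.
strict_stopped, strict_best = simulate_early_stopping(
    val_losses, patience=3, min_delta=0.15)
```

Note that in the first call training stops before the final value of 0.5 is ever reached: a low patience can end training just before a late improvement. With restore_best_weights=True, the weights kept would be the ones from best_epoch.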


 

ModelCheckpoint

 
This callback saves the model after every epoch. Here are some relevant arguments:

  • filepath: the file path where you want to save your model
  • monitor: the value being monitored
  • save_best_only: set this to True if you want the model to be saved only when the monitored value improves, so that the latest best model is not overwritten
  • mode: auto, min, or max. For example, you set mode='min' if the monitored value is val_loss and you want to minimize it.

Example:

from keras.callbacks import ModelCheckpoint

# filepath is a string that you define beforehand
checkpoint = ModelCheckpoint(filepath,
                             monitor='val_loss',
                             mode='min',
                             save_best_only=True,
                             verbose=1)
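
Callbacks only take effect once they are passed to model.fit through its callbacks argument. The sketch below shows one way to wire EarlyStopping and ModelCheckpoint into a training run. Everything around the two callbacks is an assumption made for illustration: the synthetic data, the tiny model, the temporary checkpoint path, and the choice to save only the weights (save_weights_only=True), which is used here so that the ".weights.h5" file name is accepted by recent Keras versions.

```python
import os
import tempfile

import numpy as np
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras.layers import Dense, Input
from keras.models import Sequential

# Synthetic regression data, made up purely for this sketch.
rng = np.random.default_rng(0)
x = rng.normal(size=(200, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True).astype("float32")

model = Sequential([Input(shape=(4,)), Dense(8, activation="relu"), Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Hypothetical checkpoint location in a temporary directory.
filepath = os.path.join(tempfile.mkdtemp(), "best.weights.h5")

earlystop = EarlyStopping(monitor="val_loss", patience=3,
                          restore_best_weights=True)
checkpoint = ModelCheckpoint(filepath, monitor="val_loss", mode="min",
                             save_best_only=True, save_weights_only=True)

# validation_split is needed so that val_loss exists for the callbacks to monitor.
history = model.fit(x, y, validation_split=0.2, epochs=20, batch_size=32,
                    callbacks=[earlystop, checkpoint], verbose=0)
```

Both callbacks monitor val_loss, so the fit call must provide validation data, either through validation_split as above or through validation_data.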


 

LearningRateScheduler

 

from keras.callbacks import LearningRateScheduler

# schedule is a function that you define beforehand
scheduler = LearningRateScheduler(schedule, verbose=0)


This one is pretty straightforward: it adjusts the learning rate over time using a schedule function that you write beforehand. The function takes the current epoch index as input and returns the desired learning rate as output.
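
As an illustration, here is one possible schedule function: a step decay that halves the learning rate every 10 epochs. The function name step_decay and the constants are arbitrary choices for this sketch. The optional lr parameter is there because recent Keras versions also pass the current learning rate as a second argument; this schedule simply ignores it.

```python
INITIAL_LR = 0.01      # starting learning rate (arbitrary choice)
DROP_FACTOR = 0.5      # multiply the learning rate by this at each drop
EPOCHS_PER_DROP = 10   # number of epochs between drops


def step_decay(epoch, lr=None):
    """Return the learning rate for a given epoch index (starting at 0).

    lr is the current learning rate; it is accepted but not used here.
    """
    drops_so_far = epoch // EPOCHS_PER_DROP
    return INITIAL_LR * DROP_FACTOR ** drops_so_far


# Learning rates for a few sample epochs.
sample_rates = {epoch: step_decay(epoch) for epoch in (0, 9, 10, 25)}
```

A function like this can then be passed as the schedule argument, for example LearningRateScheduler(step_decay, verbose=0).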

 

Other callback functions

 
Along with the above functions, there are other callbacks that you might encounter or want to use in your Deep Learning project:

  • History and BaseLogger: callbacks that are applied automatically to your model by default
  • TensorBoard: This is hands down my favorite Keras callback. This callback writes a log for TensorBoard, which is TensorFlow’s excellent visualization tool. If you have installed TensorFlow with pip, you should be able to launch TensorBoard from the command line: tensorboard --logdir=/full_path_to_your_logs
  • CSVLogger: This callback streams epoch results to a CSV file
  • LambdaCallback: This callback allows you to build custom callbacks

 

Conclusion

 
In this article, you have learned the main concept of callbacks in Keras and the most common callback functions. The Keras documentation has a very comprehensive page on callbacks that you should definitely check out: http://keras.io/callbacks/

Leave a comment if you have any suggestions on how to improve this post. Follow me on Medium or connect with me on LinkedIn for more quality content!

 
Bio: Andre Duong is a sophomore CS Undergrad @ UT Dallas. His interests include Machine Learning, Data Science, and Software Development.

Original. Reposted with permission.
