Advanced Keras — Accurately Resuming a Training Process
This article on practical advanced Keras use covers handling nontrivial cases where custom callbacks are used.
By Eyal Zakkay, AI / Deep Learning Engineer
In this post I will present a use case of the Keras API in which resuming a training process from a loaded checkpoint needs to be handled differently than usual.
TL;DR — If you are using custom callbacks which have internal variables that change during a training process, you need to address this when resuming by initializing these callbacks differently.
When training a Deep Learning model using Keras, we usually save checkpoints of that model’s state so we could recover an interrupted training process and restart it from where we left off. Usually this is done with the ModelCheckpoint Callback. According to the documentation of Keras, a saved model (saved with model.save(filepath)) contains the following:
- The architecture of the model, allowing to re-create the model
- The weights of the model
- The training configuration (loss, optimizer)
- The state of the optimizer, allowing to resume training exactly where you left off.
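In certain use cases, this last part isn’t exactly true: the optimizer state is restored, but any state held inside your callbacks is not saved with the model.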
Let’s say you are training a model with a custom learning rate scheduler callback, which updates the LR after each batch. Its counter variable is initialized to zero when the callback is created and keeps track of the global batch index (the batch argument in on_batch_end holds only the batch index within the current epoch).
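A minimal sketch of such a callback is shown here. The exponential decay formula and the parameter names (initial_lr, decay, min_lr) are assumptions made for illustration, not the schedule from the original project, and the import fallback exists only so the snippet can be exercised without TensorFlow installed.

```python
try:
    from tensorflow.keras.callbacks import Callback
except ImportError:
    # Fallback so the sketch can be run without TensorFlow installed.
    Callback = object


class LearningRateSchedulerPerBatch(Callback):
    """Decays the learning rate after every batch (hypothetical schedule)."""

    def __init__(self, initial_lr=1e-3, decay=0.999, min_lr=1e-5):
        super().__init__()
        self.initial_lr = initial_lr
        self.decay = decay
        self.min_lr = min_lr
        # Global batch index; it keeps growing across epochs.
        self.counter = 0

    def next_lr(self):
        """Advance the global counter and return the learning rate for it."""
        self.counter += 1
        return max(self.min_lr, self.initial_lr * self.decay ** self.counter)

    def on_batch_end(self, batch, logs=None):
        # `batch` restarts at 0 every epoch, so it is not used here.
        # Assumption: the optimizer accepts assignment to `learning_rate`
        # (older Keras code used K.set_value on optimizer.lr instead).
        self.model.optimizer.learning_rate = self.next_lr()
```

Keeping the schedule in next_lr, separate from the Keras hook, makes the dependence on counter explicit: everything the callback "remembers" about training progress lives in that one attribute.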
Let’s say we want to resume a training process from a checkpoint. The usual way would be to load the saved model, create the callbacks again, and call fit. Notice that in that case the LearningRateSchedulerPerBatch callback is initialized with counter=0 even when resuming. When training resumes, this does not recreate the conditions that existed when the checkpoint was saved: the learning rate restarts from its initial value, which is probably not what we want.
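To get a feel for the size of the mismatch, here is a small numeric sketch. The exponential decay schedule and all numbers in it are assumptions chosen for illustration, not values from an actual training run.

```python
def lr_at(counter, initial_lr=1e-3, decay=0.999):
    """Learning rate after `counter` batches under a hypothetical exponential decay."""
    return initial_lr * decay ** counter


steps_per_epoch = 500
epochs_done = 4                      # checkpoint saved after epoch 4
batches_seen = epochs_done * steps_per_epoch

lr_when_saved = lr_at(batches_seen)  # what the scheduler had reached
lr_naive_resume = lr_at(0)           # a fresh callback starts from counter = 0

print(f"LR at checkpoint:      {lr_when_saved:.6f}")
print(f"LR after naive resume: {lr_naive_resume:.6f}")
```

Under these assumed numbers the resumed run would start with a learning rate several times higher than the one in use when the checkpoint was written, even though the weights and optimizer state were restored correctly.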
The correct way to do it
We saw an example of how a wrong initialization of callbacks can lead to unwanted results when resuming. There are several ways to approach this. Here I will describe two:
Solution 1: Updating the variables with correct values
When dealing with simple cases, in which the callback has only a handful of updating variables, it is rather simple to overwrite the values of these variables before resuming. In our example, to resume with the correct value of counter, we would overwrite it on the newly created callback before calling fit.
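A minimal sketch of that overwrite follows. It assumes the checkpoint was written at an epoch boundary, so the number of batches already processed equals the completed epochs times the batches per epoch. The helper name restore_counter and its arguments are hypothetical, and a plain namespace object stands in for the real callback.

```python
from types import SimpleNamespace


def restore_counter(lr_callback, epochs_done, steps_per_epoch):
    """Set the callback's global batch counter to where training stopped."""
    lr_callback.counter = epochs_done * steps_per_epoch
    return lr_callback


# Stand-in for a freshly created LearningRateSchedulerPerBatch (counter = 0).
lr_callback = SimpleNamespace(counter=0)
restore_counter(lr_callback, epochs_done=4, steps_per_epoch=500)
# lr_callback.counter is now 2000
```

The restored callback is then passed to model.fit together with initial_epoch set to the number of completed epochs, so that the Keras epoch numbering and the scheduler continue from the same point.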
Solution 2: Saving and loading callbacks with Pickle
When our custom callbacks have many updating variables or include complex behaviors, safely overwriting each variable might be difficult. An alternative solution is to pickle the callback instance every time we save a checkpoint. When resuming, we load this pickle and reconstruct the original callback with all its correct values.
Note: in order to pickle your callbacks they must not hold any non-serializable elements. Also, in Keras versions <2.2.3 the model itself was not serializable. This prevented the pickling of any callback, since each callback also holds a reference to the model.
In this case, resuming will look something like this:
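A minimal sketch, under stated assumptions: the load_callback helper, the resume_training function, and its argument names are hypothetical, and checkpoint_callback stands for a checkpoint callback that also re-pickles the scheduler during the resumed run. The class of the pickled callback must be importable at the moment the pickle is loaded.

```python
import pickle


def load_callback(path):
    """Rebuild a callback, including its internal variables, from a pickle file."""
    with open(path, "rb") as f:
        return pickle.load(f)


def resume_training(model_path, lr_callback, checkpoint_callback,
                    x, y, initial_epoch, epochs):
    """Continue training from a saved model with already-restored callbacks."""
    # Imported lazily so load_callback can be used without TensorFlow.
    from tensorflow.keras.models import load_model

    model = load_model(model_path)  # architecture, weights, optimizer state
    model.fit(x, y,
              initial_epoch=initial_epoch,
              epochs=epochs,
              callbacks=[lr_callback, checkpoint_callback])
    return model
```

The scheduler would be restored first, for example with lr_callback = load_callback(...) on the pickle written alongside the checkpoint, and the same instance would then be handed both to the checkpoint callback and to resume_training.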
You might have noticed that I have used a modified checkpoint callback called ModelCheckpointEnhanced. This is because using the pickling method means that we also need the ModelCheckpoint callback to save pickles of the relevant callbacks. An example implementation of such modified callback might look something like this:
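A minimal sketch, assuming hypothetical constructor arguments callback_to_save and callback_filepath and a per-epoch file name pattern. The stand-in base class exists only so the snippet can be run without TensorFlow installed.

```python
import pickle

try:
    from tensorflow.keras.callbacks import ModelCheckpoint
except ImportError:
    # Minimal stand-in so the sketch can be run without TensorFlow.
    class ModelCheckpoint:
        def __init__(self, filepath, **kwargs):
            self.filepath = filepath

        def on_epoch_end(self, epoch, logs=None):
            pass


class ModelCheckpointEnhanced(ModelCheckpoint):
    """ModelCheckpoint that also pickles one other callback (hypothetical API)."""

    def __init__(self, filepath, callback_to_save, callback_filepath, **kwargs):
        super().__init__(filepath, **kwargs)
        self.callback_to_save = callback_to_save
        self.callback_filepath = callback_filepath

    def save_callback(self, epoch):
        """Pickle the tracked callback under a name derived from the epoch."""
        path = self.callback_filepath.format(epoch=epoch + 1)
        with open(path, "wb") as f:
            pickle.dump(self.callback_to_save, f)
        return path

    def on_epoch_end(self, epoch, logs=None):
        # Let the parent class handle the model checkpoint as usual.
        super().on_epoch_end(epoch, logs)
        # Then store the other callback's current state.
        self.save_callback(epoch)
```

Note that this sketch pickles the callback at the end of every epoch, whether or not the parent class actually wrote a model file (for instance when save_best_only is set), so a stricter version would tie the two together.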
The example above handles the case where you need to pickle only one callback. If you have multiple callbacks to save, you will need to make some minor modifications.
We saw how in some cases taking the naive approach for resuming a training process can lead to unwanted results. We saw examples of two ways to handle such cases in order to get consistent results when restarting an interrupted training process.
The examples I used are taken from my Keras implementation of the Sketch-RNN algorithm, a sequence to sequence Variational Autoencoder model for generation of sketches.
Original. Reposted with permission.