Callbacks in Keras

Overview

In this article, we will study callbacks in Keras. While training Deep Learning (DL) models, we often want to analyze the model weights, save them at regular intervals, schedule the learning rate according to the epoch, and add fault tolerance so that an interrupted training run can be resumed and the model deployed to make predictions. Callbacks play a crucial role in all of these scenarios: they let us perform these tasks with only a few lines of code. This article focuses on how callbacks are implemented in Keras.

Note: This article is based on the assumption that the reader has basic knowledge of training Deep Learning (DL) models using Keras.

What are callbacks in Keras?

During the training phase, i.e., while the model weights are being optimized to solve the problem at hand, callbacks give us full control over the training steps. They let us monitor model statistics, save the best model, stop, resume, or checkpoint training based on those statistics, and exercise fine-grained control over each step. Some of the major uses of callbacks in Keras are listed below:

  1. To analyze the internal statistics of the model, such as the distribution of its weights and biases.
  2. To save and restore the model weights periodically.
  3. To integrate with TensorBoard or an experiment-tracking platform to analyze the model parameters statistically.
  4. To generate training/testing/validation logs.
  5. To schedule the learning rate (lr) hyperparameter while fitting the data to the model.

Why Callbacks?

A model is trained by fitting it to the data points with the sole objective of minimizing the loss function. Training is organized into epochs, where one epoch is a full pass of the training samples through the model, with the error propagated backward to update the weights. Every epoch can be broken down into smaller steps for the training, validation/test, and prediction phases:

  1. On Batch Begin: A predefined method of the base Callback class that can be overridden in a subclass. It is invoked just before the model is trained on a new batch of data.
  2. On Batch End: A predefined method of the base Callback class that can be overridden in a subclass. It is invoked after a batch has been processed, i.e., just before the model fetches the next batch for training.
  3. On Epoch Begin: A predefined method of the base Callback class that can be overridden in a subclass. It is invoked at the beginning of each epoch.
  4. On Epoch End: A predefined method of the base Callback class that can be overridden in a subclass. It is invoked at the end of each epoch.
  5. On Predict Batch Begin: A predefined method of the base Callback class that can be overridden in a subclass. It is invoked just before the model predicts on a batch of data.
  6. On Predict Batch End: A predefined method of the base Callback class that can be overridden in a subclass. It is invoked after the model has predicted on a batch of data.
  7. On Predict Begin: A predefined method of the base Callback class that can be overridden in a subclass. It is invoked once, at the beginning of prediction.
  8. On Predict End: A predefined method of the base Callback class that can be overridden in a subclass. It is invoked once, at the end of prediction.
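
The order in which these hooks fire can be observed with a small tracing callback. The following is a minimal sketch, assuming TensorFlow 2.x with the tf.keras API; the class name HookTracer, the toy data, and the one-layer model are hypothetical and only serve the illustration.

```python
import numpy as np
import tensorflow as tf


class HookTracer(tf.keras.callbacks.Callback):
    """Hypothetical callback that records the order in which hooks fire."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch {epoch} begin")

    def on_train_batch_begin(self, batch, logs=None):
        self.events.append(f"batch {batch} begin")

    def on_train_batch_end(self, batch, logs=None):
        self.events.append(f"batch {batch} end")

    def on_epoch_end(self, epoch, logs=None):
        self.events.append(f"epoch {epoch} end")


# Toy regression data: 8 samples, so batch_size=4 gives two batches per epoch.
x = np.random.rand(8, 4).astype("float32")
y = np.random.rand(8, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

tracer = HookTracer()
model.fit(x, y, batch_size=4, epochs=1, verbose=0, callbacks=[tracer])
print(tracer.events)
```

The recorded list starts with the epoch-begin event, contains a begin/end pair per batch, and finishes with the epoch-end event.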

To have fine-grained control over all of these steps, we implement callbacks in Keras. Callbacks also let us use third-party experiment-tracking and metric-logging platforms such as wandb and Neptune. Callbacks can be divided into two categories:

  1. Legacy Callbacks
  2. Custom Callbacks

Legacy callbacks ship with the Keras API and are used by instantiating an object of the corresponding callback class. Custom callbacks are written by extending the base callback class and overriding its methods. In the sections below, we discuss both kinds of callbacks in Keras in detail.

Legacy Callbacks Offered in Keras

Keras is a high-level API, i.e., most of the implementation is wrapped in classes and functions. It ships with many default callbacks, which we refer to here as legacy callbacks. Some of them are explained below:

BackupAndRestore Callback

The BackupAndRestore callback in Keras is used to recover from an interruption during training, i.e., while the model is being fitted to the dataset. It saves the training state to a temporary checkpoint at the end of each epoch and overwrites it continuously, so at any point in time exactly one checkpoint is available. When training is interrupted and restarted, the saved checkpoint is restored into the model. It serves as a fault-tolerance mechanism.

Syntax
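
A minimal sketch of constructing and using the callback, assuming a recent TensorFlow 2.x release in which the save_freq and delete_checkpoint arguments are available. The temporary directory, toy data, and one-layer model are hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical location for the temporary checkpoint.
backup_dir = os.path.join(tempfile.mkdtemp(), "backup")

backup = tf.keras.callbacks.BackupAndRestore(
    backup_dir=backup_dir,   # where the training state is backed up
    save_freq="epoch",       # back up at the end of every epoch
    delete_checkpoint=True,  # remove the backup once training completes
)

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

# If this script is interrupted and run again with the same backup_dir,
# training resumes from the last backed-up epoch instead of from scratch.
history = model.fit(x, y, epochs=2, verbose=0, callbacks=[backup])
```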

Arguments

  1. backup_dir: Path where the model checkpoint is to be saved. Type: String
  2. save_freq: How often the checkpoint is saved. Type: String or Integer. "epoch": the checkpoint is saved at the end of each epoch. Integer n: the checkpoint is saved every n batches.
  3. delete_checkpoint: The BackupAndRestore callback works by saving a checkpoint that backs up the training state. If delete_checkpoint is set to True, the checkpoint is deleted once training completes. Set it to False to keep the checkpoint for later use. Type: Boolean

RemoteMonitor Callback

To understand the RemoteMonitor callback in Keras, consider a scenario in which a team is writing a research paper comparing model parameters. Five people work on the paper from different time zones, and each has a specialization: one is an expert in model weights and biases, another in model activations, and so on. Because the team members are spread out geographically, sharing the results through Google Drive or email would be tedious and inefficient. In this scenario, remote monitoring plays a significant role. It lets us monitor and analyze the training metrics in real time: we can provision a monitoring server and connect the training script to it through the Keras callback known as RemoteMonitor. RemoteMonitor sends an HTTP POST request. We can think of it as an API call that is triggered after each training epoch and posts the logs to the monitoring server, either as JSON or as form data.

Syntax
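
A minimal sketch of constructing the callback, assuming TensorFlow 2.x with the tf.keras API. The server address and path are placeholders for a monitoring server you would provision yourself; no request is sent until the callback is passed to model.fit.

```python
import tensorflow as tf

remote_monitor = tf.keras.callbacks.RemoteMonitor(
    root="http://localhost:9000",  # placeholder address of the monitoring server
    path="/publish/epoch/end/",    # placeholder path the logs are posted to
    field="data",                  # form field name, used when send_as_json=False
    headers=None,                  # optional dict of custom HTTP headers
    send_as_json=False,            # False: form data, True: JSON body
)

# Typical use (requires a reachable server):
# model.fit(x, y, epochs=5, callbacks=[remote_monitor])
```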

Arguments

  1. root: Root URL/domain of the server to which the request is made. Type: String
  2. path: Path, relative to the root URL, to which the training logs are posted. Type: String
  3. field: Name of the form field under which the data is stored. It is only used when the payload is sent as form data, i.e., when send_as_json is False. Type: String
  4. headers: Optional custom HTTP headers, e.g., for authentication. Defaults to None. Type: Dictionary
  5. send_as_json: Whether the request is sent as JSON (True) or as form data (False). Type: Boolean

TerminateOnNaN Callback

In extreme cases, the loss specified during model compilation can explode during training, i.e., become not a number (NaN). When this happens, we should stop training immediately and analyze the model for a possible fix. The TerminateOnNaN callback in Keras is designed for exactly this: it monitors the loss of the model and terminates training as soon as a NaN loss is encountered.

Syntax
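
A minimal sketch, assuming TensorFlow 2.x with the tf.keras API. To trigger the callback deterministically, the toy targets are deliberately filled with NaN, which is an artificial setup for illustration only.

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(16, 4).astype("float32")
# NaN targets force the loss to become NaN on the very first batch.
y = np.full((16, 1), np.nan, dtype="float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

terminate_on_nan = tf.keras.callbacks.TerminateOnNaN()  # takes no arguments

history = model.fit(x, y, epochs=5, verbose=0, callbacks=[terminate_on_nan])
print("epochs actually run:", len(history.epoch))
```

Training stops well before the five requested epochs because the callback sets the model's stop_training flag when it sees the NaN loss.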

Arguments

It accepts no arguments

ReduceLROnPlateau Callback

The ReduceLROnPlateau callback in Keras reduces the learning rate when the monitored metric has reached a plateau, which often helps the optimizer keep improving the model weights and biases. It monitors the specified metric and, if there is no improvement for a certain number of epochs (the patience), reduces the learning rate by a factor: new_lr = factor * old_lr

Syntax
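
A minimal sketch, assuming TensorFlow 2.x with the tf.keras API. To make the plateau deterministic, the callback hooks are called by hand with a constant val_loss instead of running model.fit; that manual driving is purely an illustration device, and in practice the callback is simply passed to model.fit.

```python
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1), loss="mse")

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",  # metric to watch
    factor=0.5,          # new_lr = factor * old_lr
    patience=2,          # epochs without improvement before reducing
    mode="min",          # lower val_loss is better
    min_delta=1e-4,      # smaller changes do not count as improvement
    cooldown=0,          # epochs to wait after a reduction
    min_lr=1e-6,         # lower bound on the learning rate
)

# Simulate a plateau: the same val_loss is reported for three epochs.
reduce_lr.set_model(model)
reduce_lr.on_train_begin()
for epoch in range(3):
    reduce_lr.on_epoch_end(epoch, logs={"val_loss": 1.0})

lr_after = float(model.optimizer.learning_rate.numpy())
print("learning rate after the plateau:", lr_after)
```

With patience=2, the first epoch sets the best value, the next two epochs show no improvement, and the learning rate is halved once, from 0.1 to 0.05.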

Arguments

  1. monitor: Metric to be monitored. Type: String
  2. patience: Number of epochs with no improvement after which the learning rate is reduced. Type: Integer
  3. factor: Factor by which the learning rate is multiplied when it is reduced (new_lr = factor * old_lr). Type: Float
  4. mode: "min", "max", or "auto", depending on the metric being monitored. In "min" mode, the learning rate is reduced when the monitored quantity has stopped decreasing; in "max" mode, when it has stopped increasing. In "auto" mode, the direction is inferred from the name of the monitored quantity. Type: String
  5. min_delta: Threshold for measuring an improvement. It lets us focus on significant changes only. Type: Float
  6. cooldown: Number of epochs to wait after a learning-rate reduction before resuming normal operation. Type: Integer
  7. min_lr: Lower bound on the learning rate. Type: Float

TensorBoard Callback

TensorBoard is a visualization tool that lets you inspect many aspects of your model, including its weights, biases, and gradients, and how they change throughout training (i.e., across epochs). You can also visualize classes in a multidimensional space, model performance over time, and so forth. Keras exposes TensorBoard logging as a callback that is applied at various phases of training, so the model's internals and statistics can be examined both during and after training.

Syntax
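
A minimal sketch, assuming TensorFlow 2.x with the tf.keras API. The temporary log directory, toy data, and one-layer model are hypothetical; only a subset of the arguments is shown.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical log directory; point TensorBoard at it with
#   tensorboard --logdir <log_dir>
log_dir = os.path.join(tempfile.mkdtemp(), "logs")

tensorboard = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,     # compute weight histograms every epoch
    write_graph=True,     # include the model graph
    write_images=False,   # do not write weights as images
    update_freq="epoch",  # write losses and metrics once per epoch
    profile_batch=0,      # profiling disabled
)

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

model.fit(x, y, epochs=2, validation_split=0.25, verbose=0, callbacks=[tensorboard])
print(os.listdir(log_dir))
```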

Arguments

  1. log_dir: Path where the logs will be saved. Type: String
  2. histogram_freq: Frequency (in epochs) at which weight histograms are computed for the layers of the model. If set to 0, histograms are not computed. Histogram visualizations require validation data (or a validation split) to be specified. Type: Integer
  3. write_graph: Whether to visualize the graph in TensorBoard. The log file can become quite large when write_graph is set to True. Type: Boolean
  4. write_images: Whether to write the model weights to TensorBoard as images. Type: Boolean
  5. write_steps_per_second: Whether to log the training steps per second to TensorBoard. Both epoch-level and batch-level logging are supported. Type: Boolean
  6. update_freq: Frequency at which losses and metrics are written to TensorBoard during training. Type: String or Integer. "batch": after every batch. "epoch": after every epoch. Integer n: after every n batches.
  7. profile_batch: Batch or batches to profile in order to sample compute characteristics. It must be a non-negative integer or a tuple of integers; a pair of positive integers denotes a range of batches to profile. Profiling is disabled by default. Type: Integer / Tuple(Integer)
  8. embeddings_freq: Frequency (in epochs) at which embedding layers are visualized. If set to 0, embeddings are not visualized. Type: Integer
  9. embeddings_metadata: Dictionary that maps an embedding layer's name to the name of the file in which that layer's metadata is saved. A single filename can be given if the same metadata file is used for all embedding layers. Type: Dictionary

EarlyStopping Callback

This callback lets the user monitor a specific model metric. If the metric stops improving, training is stopped automatically and, if requested, the best weights are restored. Its sole objective is to stop training when there is no further improvement in the selected metric.

Syntax
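
A minimal sketch, assuming TensorFlow 2.x with the tf.keras API. The toy data and model are hypothetical, and min_delta is deliberately set to an unrealistically large value so that no epoch ever counts as an improvement and the run stops early in a predictable way.

```python
import numpy as np
import tensorflow as tf

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="loss",             # metric to watch
    min_delta=1e3,              # artificially large: nothing counts as improvement
    patience=2,                 # epochs without improvement before stopping
    mode="min",                 # lower loss is better
    baseline=None,              # no baseline requirement
    restore_best_weights=True,  # roll back to the best epoch when stopping
)

history = model.fit(x, y, epochs=20, verbose=0, callbacks=[early_stop])
print("epochs actually run:", len(history.epoch))
```

In a real run, min_delta would be small (or 0) and monitor would usually be a validation metric such as val_loss.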

Arguments

  1. monitor: Name of the metric being monitored. Type: String
  2. min_delta: Threshold for change; a change smaller than min_delta is treated as no improvement. Type: Float
  3. patience: Number of epochs with no improvement in the monitored metric after which training is stopped. Type: Integer
  4. mode: "min", "max", or "auto", depending on the metric being monitored. In "min" mode, training stops when the monitored quantity has stopped decreasing; in "max" mode, when it has stopped increasing. In "auto" mode, the direction is inferred from the name of the monitored quantity. Type: String
  5. baseline: Baseline value for the monitored metric. Training stops if the model does not improve on the baseline. Type: Float
  6. restore_best_weights: If set to True, the model weights from the epoch with the best value of the monitored metric are restored when training stops. If set to False, the model keeps the weights from the last epoch. If the monitored metric never improves on the baseline, training runs for patience epochs and the weights from the epoch with the best metric value are restored. Type: Boolean

ModelCheckpoint Callback

The ModelCheckpoint callback in Keras saves the model, or only its weights, at a specified frequency. It lets the researcher or developer load the saved checkpoint later to continue training or to make predictions.

Syntax
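
A minimal sketch, assuming TensorFlow 2.x with the tf.keras API. The checkpoint path, toy data, and model are hypothetical. The file name ends in .weights.h5 because recent Keras releases expect that suffix for weights-only checkpoints.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical checkpoint location.
ckpt_path = os.path.join(tempfile.mkdtemp(), "best.weights.h5")

checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=ckpt_path,
    monitor="loss",          # metric used to decide what "best" means
    save_best_only=True,     # only overwrite when the metric improves
    save_weights_only=True,  # save weights, not the full model
    mode="min",              # lower loss is better
    save_freq="epoch",       # check at the end of every epoch
)

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

model.fit(x, y, epochs=2, verbose=0, callbacks=[checkpoint])

# The saved weights can later be restored into a model with the same architecture.
model.load_weights(ckpt_path)
```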

Arguments

  1. filepath: Path where the model checkpoint will be saved. Type: String
  2. monitor: Name of the metric being monitored. Type: String
  3. save_best_only: If set to True, a checkpoint is saved only when the model is considered the “best” so far, and the latest best model according to the monitored quantity is not overwritten by a worse one. If the filepath does not contain formatting options such as {epoch}, each new, better model overwrites the same file. Type: Boolean
  4. save_weights_only: If set to True, only the model weights are saved; otherwise the full model is saved. Type: Boolean
  5. mode: "min", "max", or "auto", depending on the metric being monitored. When save_best_only is True, the decision to overwrite the saved file is based on whether the monitored quantity should be minimized ("min") or maximized ("max"). In "auto" mode, the direction is inferred from the name of the monitored quantity. Type: String
  6. save_freq: How often the checkpoint is saved. Type: String or Integer. "epoch": the checkpoint is saved at the end of each epoch. Integer n: the checkpoint is saved every n batches.
  7. options: Optional tf.train.CheckpointOptions object if save_weights_only is True, or optional tf.saved_model.SaveOptions object if save_weights_only is False.
  8. initial_value_threshold: Initial floating-point “best” value of the monitored metric. It is only useful when save_best_only is set to True; the checkpoint is overwritten only if the current value of the monitored metric is better than this value. Type: Float

LearningRateScheduler Callback

The LearningRateScheduler callback in Keras runs at the start of every epoch. It calls the function we pass to it, which returns the learning rate (lr) for the current epoch, and applies the returned value to the optimizer.

Syntax
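
A minimal sketch, assuming TensorFlow 2.x with the tf.keras API. The schedule function halve_from_third_epoch, the toy data, and the model are hypothetical.

```python
import numpy as np
import tensorflow as tf


def halve_from_third_epoch(epoch, lr):
    """Hypothetical schedule: keep lr for epochs 0 and 1, then halve it."""
    return lr if epoch < 2 else lr * 0.5


scheduler = tf.keras.callbacks.LearningRateScheduler(halve_from_third_epoch, verbose=0)

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1), loss="mse")

# Epoch indices are zero-based, so three epochs call the schedule with 0, 1, 2.
model.fit(x, y, epochs=3, verbose=0, callbacks=[scheduler])

final_lr = float(model.optimizer.learning_rate.numpy())
print("learning rate after training:", final_lr)
```

After three epochs the optimizer holds half of the initial learning rate, because only the call for epoch index 2 applied the reduction.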

Arguments

  1. schedule: Function, which accepts epoch and current learning rate as an argument and returns new learning rate. Type: Function Object

Other Callbacks

CSVLogger Callback

The CSVLogger callback in Keras saves the training logs, i.e., the loss and the metrics specified when the model was compiled, to a CSV file that can be used later to analyze the model statistics.

Syntax
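
A minimal sketch, assuming TensorFlow 2.x with the tf.keras API. The log file path, toy data, and model are hypothetical.

```python
import csv
import os
import tempfile

import numpy as np
import tensorflow as tf

# Hypothetical log file location.
log_path = os.path.join(tempfile.mkdtemp(), "training_log.csv")

csv_logger = tf.keras.callbacks.CSVLogger(
    filename=log_path,
    separator=",",  # delimiter between values
    append=False,   # start a fresh file instead of appending
)

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

model.fit(x, y, epochs=2, verbose=0, callbacks=[csv_logger])

# Read the log back: one header row followed by one row per epoch.
with open(log_path, newline="") as f:
    rows = list(csv.reader(f))
print(rows[0])
```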

Arguments

  1. filename: Name of the file where the log is to be saved. Type: String
  2. separator: String used to separate the values in the log file. Type: String
  3. append: Whether to append to an existing file or create a new log file. Type: Boolean

LambdaCallback

LambdaCallback is a shorthand for writing custom callbacks in Keras. It lets us pass anonymous functions that are called during the model.fit/evaluate/predict phase.

The functions can be attached to any of the six stages listed below.

  1. on_epoch_begin: Invoked at the beginning of each epoch; accepts two arguments, epoch and logs.
  2. on_epoch_end: Invoked at the end of each epoch; accepts two arguments, epoch and logs.
  3. on_batch_begin: Invoked at the beginning of each batch; accepts two arguments, batch and logs.
  4. on_batch_end: Invoked at the end of each batch; accepts two arguments, batch and logs.
  5. on_train_begin: Invoked at the beginning of the training process; accepts one argument, logs.
  6. on_train_end: Invoked at the end of the training process; accepts one argument, logs.

Syntax
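
A minimal sketch, assuming TensorFlow 2.x with the tf.keras API. Only the epoch-level and train-level stages are used here; the list epoch_losses, the toy data, and the model are hypothetical.

```python
import numpy as np
import tensorflow as tf

epoch_losses = []  # filled by the anonymous on_epoch_end function

lambda_cb = tf.keras.callbacks.LambdaCallback(
    on_train_begin=lambda logs: print("training started"),
    on_epoch_end=lambda epoch, logs: epoch_losses.append(logs["loss"]),
    on_train_end=lambda logs: print("training finished"),
)

x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

model.fit(x, y, epochs=3, verbose=0, callbacks=[lambda_cb])
print(epoch_losses)
```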

Arguments

  1. on_epoch_begin, on_epoch_end, on_batch_begin, on_batch_end, on_train_begin, on_train_end: One keyword argument per stage listed above. Each takes the function to be invoked at that stage. Type: Function Object

BaseLogger Callback

The BaseLogger callback in Keras is applied by default to every model. It takes the metrics that are tracked for the model and accumulates their averages over each epoch.

Syntax

Note: It is applied by default. We only instantiate this callback explicitly when we want to exclude some metrics, i.e., when certain metrics should not be averaged over an epoch.

Arguments

  • stateful_metrics: List of metric names that should not be averaged over an epoch. The metrics in this list are logged as is, whereas the remaining metrics are averaged and logged in on_epoch_end.

Implementing Callbacks in Keras

In this section, we will develop a better understanding of how to implement custom callbacks in Keras along with the legacy callbacks discussed above, and write code snippets to put them to work.

The custom callbacks in Keras can be implemented by inheriting the base class named tf.keras.callbacks.Callback and overriding the following functions:

  1. On Epoch Begin
  2. On Epoch End
  3. On Train/Test/Predict Batch Begin
  4. On Train/Test/Predict Batch End
  5. On Train/Test/Predict Begin
  6. On Train/Test/Predict End

The code below walks through callbacks in Keras in depth. Since the objective of this article is to explain how callbacks are implemented, I will focus on the callback-related parts of the code.

Code

Step 1. Importing required libraries and the TensorBoard extension

The first step is to import the required libraries. The snippet imports the Keras classes and functions used in the rest of the code, along with the TensorBoard extension.
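
A minimal sketch of such an import cell, assuming TensorFlow 2.x. The exact set of imports in the original notebook is not known, so the selection below is an assumption based on the callbacks used later in the article.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.callbacks import (
    CSVLogger,
    EarlyStopping,
    LambdaCallback,
    LearningRateScheduler,
    ModelCheckpoint,
    TensorBoard,
)

# In a Jupyter/Colab notebook, the TensorBoard extension is typically loaded with:
# %load_ext tensorboard

print("TensorFlow version:", tf.__version__)
```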

Step 2. Data Preprocessing

For demonstration purposes, I have used the MNIST digit dataset, which consists of 70,000 grayscale (single-channel) images of handwritten digits from 0 to 9. The dataset is loaded from Keras, where it is already divided into a training set of 60,000 images and a test set of 10,000 images. I have preprocessed it by normalizing the pixels, converting the labels into categorical (one-hot) values, and reshaping the images to (28, 28, 1), because we are going to train a 2D Convolutional Neural Network. The pixels are scaled by dividing each one by 255, so that every value lies between 0 and 1. The code snippet below depicts the preprocessing steps for the MNIST digit dataset.
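
The preprocessing steps described above can be sketched as follows, assuming TensorFlow 2.x. The helper name preprocess is hypothetical, and the dataset download is left commented out so that the snippet runs without network access.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 10  # digits 0 to 9


def preprocess(images, labels):
    """Scale pixels to [0, 1], add a channel axis, and one-hot encode labels."""
    images = images.astype("float32") / 255.0
    images = np.expand_dims(images, axis=-1)  # (n, 28, 28) -> (n, 28, 28, 1)
    labels = tf.keras.utils.to_categorical(labels, NUM_CLASSES)
    return images, labels


# Typical use with the Keras copy of MNIST (downloads the data on first call):
# (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
# x_train, y_train = preprocess(x_train, y_train)
# x_test, y_test = preprocess(x_test, y_test)

# Quick check on random stand-in images with the same layout as MNIST.
demo_images = np.random.randint(0, 256, size=(5, 28, 28), dtype=np.uint8)
demo_labels = np.array([0, 1, 2, 9, 5])
demo_x, demo_y = preprocess(demo_images, demo_labels)
print(demo_x.shape, demo_y.shape)
```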

Output

Step 3. Model Creation

Now that the dataset is preprocessed, it is time to create the model we will train. For the sake of simplicity, I have constructed a simple 2D Convolutional Neural Network (CNN) with two Conv2D layers with 64 and 128 filters respectively and a kernel size of 3×3, along with max pooling with a pool size of 2×2, a dropout layer with a probability of 0.5, and one dense layer with 10 neurons and a softmax activation, because we have 10 classes. The code snippet below creates the model described above.
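
A minimal sketch of a model matching that description, assuming TensorFlow 2.x. The layer order, the ReLU activations, and the placement of a pooling layer after each convolution are assumptions, since the text only lists the layer types and sizes; the function name build_cnn is hypothetical.

```python
import tensorflow as tf
from tensorflow.keras import layers


def build_cnn(input_shape=(28, 28, 1), num_classes=10):
    """Small CNN: two conv/pool stages, dropout, and a softmax classifier."""
    return tf.keras.Sequential(
        [
            tf.keras.Input(shape=input_shape),
            layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Conv2D(128, kernel_size=(3, 3), activation="relu"),
            layers.MaxPooling2D(pool_size=(2, 2)),
            layers.Flatten(),
            layers.Dropout(0.5),
            layers.Dense(num_classes, activation="softmax"),
        ]
    )


model = build_cnn()
model.summary()  # prints layer names, output shapes, and parameter counts
```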

After creating the model, we display its summary in the output cell. The summary lists each layer's name and output shape, along with the number of trainable and non-trainable parameters. The summary of the created model is shown below:

Output

Step 4. Callbacks

In this section, we will implement the callback in Keras with the MNIST-digit classifier model which we have created in Step 3. We will implement the following callbacks:

  1. Custom callback
  2. EarlyStopping
  3. ModelCheckpoint
  4. CSVLogger
  5. Tensorboard
  6. LearningRateScheduler
  7. LambdaCallback

The callbacks are listed in the order in which they are implemented in the code. The code snippet below is a function that accepts the epoch number and the learning rate of the previous epoch and returns the updated learning rate for the current epoch; it is the function used with the LearningRateScheduler callback. It returns the learning rate unchanged if the epoch number is even; otherwise it returns a modified learning rate.
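
A minimal sketch of such a schedule function in plain Python. The text does not say how the learning rate is modified on odd epochs, so the exponential decay factor used here is an assumption, and the name lr_schedule is hypothetical.

```python
import math


def lr_schedule(epoch, lr):
    """Return lr unchanged on even epochs, a decayed lr on odd epochs."""
    if epoch % 2 == 0:
        return lr
    # Assumed modification: multiply by exp(-0.1) on odd epochs.
    new_lr = lr * math.exp(-0.1)
    print(f"epoch {epoch}: learning rate changed from {lr:.6f} to {new_lr:.6f}")
    return new_lr


# The function is later wrapped as:
#   tf.keras.callbacks.LearningRateScheduler(lr_schedule)
even_lr = lr_schedule(0, 0.001)
odd_lr = lr_schedule(1, 0.001)
```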

The code snippet below inherits from the base class tf.keras.callbacks.Callback and overrides its methods as required. For demonstration purposes, I have added a print statement and evaluated the model on the test dataset. The evaluation metric is added to the training logs, where it can be seen during training.
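
A minimal sketch of such a custom callback, assuming TensorFlow 2.x. The class name TestSetEvaluator, the log key test_loss, and the toy regression data standing in for MNIST are all hypothetical.

```python
import numpy as np
import tensorflow as tf


class TestSetEvaluator(tf.keras.callbacks.Callback):
    """Hypothetical callback: evaluate on held-out data after every epoch."""

    def __init__(self, x_test, y_test):
        super().__init__()
        self.x_test = x_test
        self.y_test = y_test

    def on_epoch_end(self, epoch, logs=None):
        results = self.model.evaluate(
            self.x_test, self.y_test, verbose=0, return_dict=True
        )
        print(f"epoch {epoch}: test loss = {results['loss']:.4f}")
        if logs is not None:
            # Adding the value to logs makes it visible to later callbacks.
            logs["test_loss"] = results["loss"]


# Toy stand-in data.
x_train = np.random.rand(32, 4).astype("float32")
y_train = np.random.rand(32, 1).astype("float32")
x_test = np.random.rand(8, 4).astype("float32")
y_test = np.random.rand(8, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

history = model.fit(
    x_train, y_train, epochs=2, verbose=0,
    callbacks=[TestSetEvaluator(x_test, y_test)],
)
print(history.history.keys())
```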

The snippet below collects all the callbacks into a list that is passed to model.fit as an argument. There are several ways to do this: we could also pass each callback object directly in the call to model.fit, but building the list up front keeps the training call short and easy to read.
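
A minimal sketch of assembling such a list, assuming TensorFlow 2.x. The output directory, file names, schedule, and the helper name build_callbacks are hypothetical; a custom callback instance would be added to the same list.

```python
import os
import tempfile

import tensorflow as tf


def build_callbacks(output_dir):
    """Collect the legacy callbacks used in this article into one list."""
    return [
        tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(output_dir, "best.weights.h5"),
            monitor="val_loss",
            save_best_only=True,
            save_weights_only=True,
        ),
        tf.keras.callbacks.CSVLogger(os.path.join(output_dir, "training_log.csv")),
        tf.keras.callbacks.TensorBoard(log_dir=os.path.join(output_dir, "logs")),
        tf.keras.callbacks.LearningRateScheduler(
            lambda epoch, lr: lr if epoch % 2 == 0 else lr * 0.9
        ),
        tf.keras.callbacks.LambdaCallback(
            on_epoch_end=lambda epoch, logs: print(f"finished epoch {epoch}")
        ),
    ]


callbacks = build_callbacks(tempfile.mkdtemp())
print([type(cb).__name__ for cb in callbacks])

# Later: model.fit(x_train, y_train, ..., callbacks=callbacks)
```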

Step 5. Compiling and Training the Model

In this section, we are going to compile our model. To compile the model, we need to specify the loss function, the optimizer, and the metric. I have used categorical_crossentropy as the loss function because this is a multiclass problem and we have converted the labels into categorical (one-hot) values. Adam was selected as the optimizer that updates the weights from the backpropagated error. Adam is an extension of Stochastic Gradient Descent that combines ideas from Root Mean Square Propagation (RMSProp) and the Adaptive Gradient Algorithm (AdaGrad). For the metric, we have used accuracy for simplicity; you can use any metric suited to your problem statement. The snippet below depicts the code for model compilation.
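
A minimal sketch of the compilation step, assuming TensorFlow 2.x. A small stand-in classifier is used so that the snippet is self-contained; in the article's flow, the CNN from Step 3 would be compiled the same way.

```python
import tensorflow as tf
from tensorflow.keras import layers

# Stand-in 10-class classifier (hypothetical; replace with the Step 3 CNN).
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(10, activation="softmax"),
    ]
)

model.compile(
    loss="categorical_crossentropy",        # labels are one-hot encoded
    optimizer=tf.keras.optimizers.Adam(),   # default learning rate
    metrics=["accuracy"],
)
```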

After the model compiles successfully, the final step is to train it. The training data is split into two sets, i.e., a training set and a validation set. The validation_split argument gives the fraction of the data to hold out; in our case it is 0.1, which means that ten percent of the data is used for validation and the remaining ninety percent is used to train the model, with a batch size of 128. The snippet below depicts the code for model training.
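
A minimal sketch of the training call, assuming TensorFlow 2.x. Random arrays shaped like MNIST and a small stand-in classifier are used so that the snippet runs quickly and without downloads; the number of epochs and the single callback are arbitrary choices for illustration.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

# Random stand-in data with the same shapes as preprocessed MNIST.
x_train = np.random.rand(256, 28, 28, 1).astype("float32")
y_train = tf.keras.utils.to_categorical(np.random.randint(0, 10, size=256), 10)

# Stand-in classifier (hypothetical; replace with the Step 3 CNN).
model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28, 1)),
        layers.Flatten(),
        layers.Dense(10, activation="softmax"),
    ]
)
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

history = model.fit(
    x_train,
    y_train,
    batch_size=128,
    epochs=2,
    validation_split=0.1,  # hold out 10% of the training data for validation
    callbacks=[tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3)],
    verbose=0,
)
print(sorted(history.history.keys()))
```

Because of validation_split, the history contains validation entries (prefixed with val_) next to the training loss and accuracy.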

Finally, when the training code runs, we see output like the image below, which includes the print statements we added in the callback implementations, i.e., the LearningRateScheduler, custom, and Lambda callbacks.

[Image: training output with callbacks in Keras]

Some of the callbacks, i.e., CSVLogger, TensorBoard, and ModelCheckpoint, save logs or files during training. The image below shows the logs/files saved by the callbacks.

[Image: log data saved via callbacks]

Among the callbacks, we included the TensorBoard visualization tool. The image below depicts the TensorBoard output after the model is trained.

[Image: TensorBoard with callbacks in Keras]

Conclusion

In this article, we have studied the fundamentals of Keras callbacks (legacy and custom). The key takeaways from this article are:

  • Callbacks give us the upper hand while training any Deep Learning model, because they let us control what happens in each epoch.
  • Callbacks let us log the weights after each epoch so that we can analyze the weight distribution of the model.
  • Callbacks let us visualize the model's parameters and metrics in the form of graphs and share them.
  • Callbacks give us the option to schedule the learning rate of the model with respect to the epoch or training phase.
  • Callbacks in Keras can also act as a fault-tolerance mechanism, letting us resume training from the point (epoch) where it stopped.