Useful Callbacks in Keras

Many useful applications of machine learning use deep learning as their backbone technology. Deep learning is a key reason behind the current wave of AI we see these days.

A deep learning engineer solving various problems may use several deep learning libraries on a day-to-day basis. TensorFlow and Keras are two popular Python-based deep learning libraries; since TensorFlow 2.0, Keras is also bundled with TensorFlow as tf.keras.

Apart from building various types of neural networks, they also provide functionality that eases the training and monitoring of deep learning models.

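In this article, I list some of these functionalities that are used very frequently and come in handy for any deep learning engineer. Specifically, I cover callbacks: objects passed to “model.fit” that Keras calls at fixed points during training, such as the end of every epoch.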

Model Checkpoint

During training, a model’s weights are updated at every iteration. After every epoch, the model’s performance on held-out validation data is reported using metrics such as accuracy or F1-score. A natural requirement is to save the best-performing model weights seen so far for later use. This is because, unlike some classical methods, training a deep network does not guarantee that performance improves from iteration $k$ to $k+1$.

ModelCheckpoint is a built-in Keras callback class that we can simply import.

from tensorflow.keras.callbacks import ModelCheckpoint

As a next step, we create an instance of it, which is later passed to the “model.fit” function. The following code snippet shows that.

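model_checkpoint = ModelCheckpoint(
    filepath="best_model.h5",  # HDF5 file holding the full model (architecture + weights)
    verbose=1,  # print a message each time the file is (re)written
    save_best_only=True,  # overwrite only when the monitored metric improves
    mode="max",  # larger val_accuracy is better
    monitor="val_accuracy",  # metric computed on validation_data at each epoch end
)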

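Here the first argument, filepath, is the path of the file where the best model is saved. Because save_weights_only defaults to False, the whole model (architecture and weights) is stored, so it can be reloaded later with tf.keras.models.load_model. The monitor argument names the metric that decides whether the model is saved, and mode="max" states that larger values of that metric are better (use mode="min" when monitoring a loss). The verbose argument switches messages on or off: with verbose=1 the callback prints a line whenever it saves, while verbose=0 keeps it silent. Finally, save_best_only=True ensures that only the single best-performing model on the validation data is kept, because the file is overwritten only when the monitored metric improves.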
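To make the interplay of monitor, mode and save_best_only concrete, here is a simplified plain-Python model of the decision rule that needs no TensorFlow. It is a sketch under the assumption (matching my understanding of Keras, but not copied from its source) that a save happens only on a strict improvement over the best value seen so far, starting from minus infinity for mode="max" and plus infinity for mode="min". The function name epochs_that_save is hypothetical and is not part of Keras.

```python
import math
from typing import List, Sequence


def epochs_that_save(values: Sequence[float], mode: str = "max") -> List[int]:
    """Return the 1-based epochs at which a save-best-only checkpoint writes a file.

    values: the monitored metric (e.g. val_accuracy) reported after each epoch.
    mode:   "max" if larger is better (accuracy), "min" if smaller is better (loss).
    """
    if mode not in ("max", "min"):
        raise ValueError("mode must be 'max' or 'min'")
    best = -math.inf if mode == "max" else math.inf
    saved = []
    for epoch, value in enumerate(values, start=1):
        improved = value > best if mode == "max" else value < best
        if improved:  # strict improvement only; a tie does not trigger a save
            best = value
            saved.append(epoch)
    return saved


if __name__ == "__main__":
    # Made-up validation accuracies for five epochs.
    print(epochs_that_save([0.81, 0.86, 0.84, 0.88, 0.87]))  # [1, 2, 4]
```

With these made-up numbers, the file on disk after five epochs holds the epoch-4 model, even though training ended at epoch 5 with a lower score.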

In the final step, this callback object is passed, together with any other callbacks, as a list to the callbacks argument of “model.fit”, as shown below.

history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    callbacks=[model_checkpoint],
)

TensorBoard Callback

The next cool callback is TensorBoard. It lets us track the loss and other performance metrics across multiple models, and it can also visualize histograms of the model weights as well as the model graph. The plots produced by TensorBoard are interactive: hovering over a curve displays the values of its data points.

The following image shows what a typical TensorBoard window looks like.

[Figure: TensorBoard callback]

As you can see, the plots of three different models trained on the same data are shown together in one interactive TensorBoard graph.

The code for the TensorBoard callback is as follows.

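from tensorflow.keras.callbacks import TensorBoard

# Logs for this run go to their own subdirectory; histogram_freq=1 records
# weight histograms every epoch, and write_graph=True saves the model graph.
tensor_board_obj = TensorBoard(
    log_dir="./logs/fits/model_1", histogram_freq=1, write_graph=True
)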

history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    callbacks=[tensor_board_obj])

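The code is similar to the previous case: we import the TensorBoard class, create an instance of it, and pass it in the callbacks list along with any other callbacks. To view the logs, start TensorBoard from the command line with tensorboard --logdir ./logs/fits; each subdirectory under that root (such as model_1) shows up as a separate run.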
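Comparing several models in TensorBoard, as in the three-model comparison described above, works by giving each training run its own subdirectory under a shared log root. A common pattern is to append a timestamp so that reruns do not overwrite each other. The helper below is a hypothetical sketch of that convention; its name and the timestamp format are my own choices, not a Keras API.

```python
import os
from datetime import datetime
from typing import Optional


def run_log_dir(root: str, run_name: str, now: Optional[datetime] = None) -> str:
    """Return a per-run log directory such as ./logs/fits/model_1-20240101-120000.

    Each run gets its own subdirectory under `root`, so TensorBoard pointed at
    `root` lists the runs side by side.
    """
    now = now or datetime.now()
    stamp = now.strftime("%Y%m%d-%H%M%S")
    return os.path.join(root, f"{run_name}-{stamp}")


if __name__ == "__main__":
    # Fixed time so the output is reproducible; prints
    # ./logs/fits/model_1-20240101-120000 on POSIX systems.
    print(run_log_dir("./logs/fits", "model_1", datetime(2024, 1, 1, 12, 0, 0)))
```

In the training script, the result would be passed as log_dir, for example TensorBoard(log_dir=run_log_dir("./logs/fits", "model_1"), histogram_freq=1), while TensorBoard itself is pointed at the ./logs/fits root.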

Early Stopper Callback

The last callback in our list is a custom one that stops model training once the model reaches a desired performance threshold.

To do this, we write a subclass of the Keras “Callback” base class. Once the target performance is achieved, it sets the model’s stop_training flag to True, and “model.fit” then ends training instead of starting the next epoch.

Here I use validation accuracy as the performance metric for early stopping. The code, including the subclass, is shown below.

from tensorflow.keras.callbacks import Callback


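class EarlyStopper(Callback):
    """Stop training once validation accuracy reaches a target value."""

    def __init__(self, target):
        super().__init__()
        self.target = target

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes this epoch's metrics in `logs`; avoid a mutable default.
        logs = logs or {}
        # "val_accuracy" exists only if validation_data is given and the model
        # was compiled with metrics=["accuracy"].
        validation_acc = logs.get("val_accuracy")
        if validation_acc is not None and validation_acc >= self.target:
            self.model.stop_training = True


stop_early_callback = EarlyStopper(0.95)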


history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_test, y_test),
    callbacks=[stop_early_callback])

The workflow is the same as before: import the parent class, define the subclass, create an instance of it, and pass that instance in the callbacks list when fitting the model.
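What “model.fit” does with such a callback can be imitated in a few lines of plain Python. In this simplified mock, FakeModel, TargetStopper and fake_fit are hypothetical names, not Keras classes: the loop hands each epoch’s logs to every callback’s on_epoch_end and stops as soon as model.stop_training becomes True, which is, in simplified form, how Keras reacts to the flag. The per-epoch accuracies in the usage example are made up.

```python
from typing import Dict, List, Optional


class FakeModel:
    """Stand-in for a Keras model; only the stop_training flag matters here."""

    def __init__(self) -> None:
        self.stop_training = False


class TargetStopper:
    """Same idea as the custom callback, minus the Keras base class."""

    def __init__(self, target: float, monitor: str = "val_accuracy") -> None:
        self.target = target
        self.monitor = monitor
        self.model: Optional[FakeModel] = None  # attached by fake_fit, like Keras's set_model

    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        value = (logs or {}).get(self.monitor)
        if value is not None and value >= self.target and self.model is not None:
            self.model.stop_training = True


def fake_fit(model: FakeModel, per_epoch_logs: List[Dict[str, float]], callbacks: list) -> int:
    """Simplified fit(): call on_epoch_end on each callback and honour stop_training.

    Returns the number of epochs that actually ran.
    """
    for cb in callbacks:
        cb.model = model
    epochs_run = 0
    for epoch, logs in enumerate(per_epoch_logs):
        epochs_run += 1  # the real fit() would train on the data here
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
        if model.stop_training:  # a callback raised the flag: end training early
            break
    return epochs_run


if __name__ == "__main__":
    made_up_logs = [{"val_accuracy": a} for a in (0.80, 0.93, 0.96, 0.97)]
    print(fake_fit(FakeModel(), made_up_logs, [TargetStopper(0.95)]))  # 3
```

Training stops after the third epoch, the first one whose accuracy reaches 0.95, so the fourth epoch never runs.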

These are some of the most frequently used callbacks (among the many that Keras provides), and they are useful in almost any deep learning project. I thought the basic template code for each might be useful for deep learning practitioners. If you know other handy Keras callbacks, let me know in the comments.
