How to Save a PyTorch Model?
Overview
Training large deep neural networks is a computationally expensive and time-consuming process, and we do not want to train a model all over again once it is fully trained. To this end, this article uses code examples to explain how to save a model in PyTorch that is entirely (or partially) trained on a dataset, so that we can load the saved models for inference without retraining them every single time.
Introduction
- Training deep neural networks is one of the most time- and resource-consuming steps in any model-building pipeline. Moreover, with the advent of very large neural models containing as many as billions of parameters (for example, GPT-3 contains 175 billion parameters), the training time for these state-of-the-art models is also growing rapidly. Consequently, many computational resources go into training these models to a stage where they can produce meaningful results.
- This means that once the model training step is complete, we need to be able to save the trained model in some way so that we can load it for later use during inference or deployment. Without an easy and efficient way to save trained models, we would have to train them again every time we need them for inference or any other use.
- In addition to this, there are other reasons why saving models comes in handy - for example, we might want to save a model partway through training due to compute limitations and the like, and then load it back in the same state we left it off at (from the same checkpoint) and train it further for more epochs.
- For instance, Kaggle, Google Colab, and most other free cloud services hosting notebook platforms enforce timeout and idleness limits, beyond which the notebook gets disconnected from its runtime; the session is also interrupted once a maximum runtime is reached. This is another scenario where saving a fully or partially trained model becomes indispensable.
That said, we will now move on to learn different ways to save a model in PyTorch.
Saving & Loading Model for Inference
First, we will look at three different ways of saving and loading trained PyTorch models for inference. By inference, we mean that once the model is fully trained and performs well according to some criterion, we want to save it and load it back later to make predictions. Let us first learn about the mechanics of a common way to do that - state_dict in PyTorch.
What is a state_dict?
- As we may already know, in PyTorch the optimizable (or tunable) parameters of any model, namely its weights and biases, can be accessed using model.parameters().
- A state_dict is simply a Python dictionary object that stores information about all the learnable model parameters, mapping each model layer (as the key) to its parameter tensors (as the corresponding values).
- All the learnable model parameters, as well as registered buffers such as the running statistics of batchnorm layers, are saved in the model's state_dict.
- In the same way, optimizer objects also have their own state_dict, containing information about the optimizer's state (in the case of state-maintaining optimizers like Adam) as well as the hyperparameters used, such as the learning rate or, in the case of Stochastic Gradient Descent, the momentum.
Let us first look at how a state_dict manifests itself in code, after which we will leverage it to learn one way of saving and loading models in PyTorch.
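Here is a minimal sketch: the tiny network, its layer sizes, and the optimizer below are purely illustrative, but they are enough to inspect both state_dicts.

```python
import torch
import torch.nn as nn

# An illustrative model - any torch.nn.Module would do here.
model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Each key names a layer's parameter; each value is its tensor.
for name, tensor in model.state_dict().items():
    print(name, "\t", tensor.size())

# The optimizer's state_dict holds its state and hyperparameters.
print(optimizer.state_dict())
```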
That is what state_dict looks like. We will now leverage it to save and load a trained model in PyTorch - this is one way of saving and loading trained PyTorch models for inference.
We could use torch.save to save the model's state_dict in a specified PATH, like so:
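A minimal sketch, assuming model is the network from the snippet above; the file path model_weights.pth is an arbitrary choice:

```python
PATH = "model_weights.pth"  # arbitrary file path of our choosing
torch.save(model.state_dict(), PATH)
```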
And then load the state_dict back using torch.load like so:
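Note that we must first re-create a model with the same architecture, since the state_dict stores only the parameter tensors, not the model class:

```python
# Re-create the architecture, then load the saved weights into it.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.load_state_dict(torch.load(PATH))
model.eval()  # put dropout/batchnorm layers in evaluation mode for inference
```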
This is the simplest way to load and use an already trained model for inference sometime later.
We will now learn another way to save a model in PyTorch, but before that, let us briefly understand a closely related concept, namely the pickle module in Python.
What is Pickle in Python?
In the simplest terms, the pickle module in Python is used for serializing a Python object structure into a binary format and de-serializing it back.
Pickling is the process of converting a Python object hierarchy into a byte stream; conversely, "unpickling" is the inverse operation, wherein a byte stream (from a binary file or bytes-like object) is converted back into an object hierarchy.
There is a way to save the entire model instance (object) in PyTorch by leveraging Python's pickle module. We can use torch.save and torch.load to save and load the model object altogether, like so:
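A minimal sketch; note that, unlike the state_dict approach, we pass the whole model object to torch.save (the file name is again our own choice):

```python
PATH = "entire_model.pth"  # arbitrary file path

# Save the whole model object (architecture + weights) via pickle.
torch.save(model, PATH)

# Load it back without re-instantiating the class first - but the class
# definition must still be importable from the same location as before.
model = torch.load(PATH)
model.eval()
```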
This way, the entire module (the model, which is an instance of torch.nn.Module) is saved using Python's pickle module.
This approach has a bottleneck: the serialized data is bound to the specific classes and the exact directory structure used when the model was saved.
This is because pickle does not save the model class itself; rather, it saves a path to the file containing the class, which is then used when loading the model with torch.load. Unfortunately, this means the loading code can break in various ways when used in other projects or after code refactors; the entire directory structure has to stay the same at load time.
Next, we will learn about another format, the TorchScript format, for saving and loading models in PyTorch.
Export/Load Model in TorchScript Format
The TorchScript format is an intermediate representation of a PyTorch model that can be run both in Python and in a high-performance environment like C++. Saving models in the TorchScript format is recommended when models are to be used for scaled inference and deployment.
Unlike the pickle format, TorchScript allows us to load the exported model and run inference without defining the model class.
To save trained models in TorchScript format, we first export the trained model to the required format using torch.jit.script and then save it, like so:
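A minimal sketch, reusing the illustrative model from the earlier snippets; the file name is our own choice:

```python
# Export the trained model to TorchScript via scripting, then save it.
model_scripted = torch.jit.script(model)
model_scripted.save("model_scripted.pt")  # arbitrary file name
```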
And later, we can load the model back using torch.jit.load, like so:
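Since the TorchScript file carries the model definition along with the weights, no model class is needed at load time:

```python
model = torch.jit.load("model_scripted.pt")
model.eval()
```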
With this, we have learned three ways to save a trained model in PyTorch and use it later for inferential purposes.
Saving & Loading a General Checkpoint
This section looks at slightly different ways of saving PyTorch models, for use cases like saving an intermediate checkpoint that can later be used to resume model training.
- While checkpointing models, we want to be able to save more than just the model's state_dict (or parameter weights of tunable parameters).
- To continue training, it is also important to save the optimizer's state (via the optimizer's state_dict), so that the running statistics maintained by the optimizer instance keep updating the model weights correctly as training resumes. Without these running stats, a lot of noise would be introduced into the convergence trajectory of the optimizable parameters during training.
- Other than that, one might also want to save the epoch the model was left off at, the training loss recorded during that epoch, any external torch.nn.Embedding layers, and so on.
- To save the components above, we can organize them in a Python dictionary object and use torch.save() to serialize the dictionary, very similar to how we saved the model instance in the pickle format discussed earlier (see the sketch after this list).
(The common PyTorch convention is to save such checkpoints with the .tar file extension.)
- To load the saved checkpoint back, we first need to initialize both the model and the optimizer instances and then load the saved dictionary locally using torch.load(). The loaded dictionary can then be queried to access the saved information, like so:
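Below is a minimal, self-contained sketch of both steps; the model, the optimizer, the dictionary keys, and the file name are illustrative choices rather than fixed PyTorch requirements:

```python
import torch
import torch.nn as nn

# Illustrative model and optimizer - any torch.nn.Module works here.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

EPOCH = 0   # placeholder values; in practice they come from the training loop
LOSS = 0.0
PATH = "model_checkpoint.tar"  # .tar is the usual convention for checkpoints

# Save everything needed to resume training in one serialized dictionary.
torch.save({
    "epoch": EPOCH,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "loss": LOSS,
}, PATH)

# To resume: re-initialize the model and optimizer, then load and query.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
print(epoch, loss)
```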
Note that in the above code snippet, the epoch and loss take the placeholder values 0 and 0.0 respectively, which would not make sense from a practical point of view: the model has not yet been trained for any epochs, so there is no real loss value to record.
In practice, models are saved only after training for a certain number of epochs, so these variables would take realistic values accordingly.
Saving Multiple Models in One File
All good so far! We have learned different ways of saving and loading a single trained PyTorch model. However, we can also save multiple models, in much the same way as for a single model. So let us look at how saving multiple models works in PyTorch.
- In PyTorch, we can save more than one model, that is, a model composed of multiple torch.nn.Modules, such as a Generative Adversarial Network (GAN), a sequence-to-sequence model, or an ensemble of different models.
- To do this, we can follow the same approach we used when saving a single model as a checkpoint. Specifically, we save a dictionary holding each model's state via its state_dict and each corresponding optimizer's state_dict, along with any other items that might be required later to resume training. Like before, this is done by simply serializing a Python dictionary via torch.save.
- Again, like earlier, the .tar file extension is the common PyTorch convention for saving such files.
- To load the saved dictionary back, we initialize the different models and their corresponding optimizers and then load the dictionary using torch.load(). Like earlier, querying the dictionary yields the required items, as the sketch below shows.
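A minimal sketch with two toy models; the models, key names, and file name are all illustrative:

```python
import torch
import torch.nn as nn

# Two illustrative models (e.g., the two halves of a larger system).
modelA = nn.Linear(10, 5)
modelB = nn.Linear(5, 2)
optimizerA = torch.optim.SGD(modelA.parameters(), lr=0.01)
optimizerB = torch.optim.SGD(modelB.parameters(), lr=0.01)

PATH = "models_checkpoint.tar"  # .tar extension, as per convention

# Save all states in a single serialized dictionary.
torch.save({
    "modelA_state_dict": modelA.state_dict(),
    "modelB_state_dict": modelB.state_dict(),
    "optimizerA_state_dict": optimizerA.state_dict(),
    "optimizerB_state_dict": optimizerB.state_dict(),
}, PATH)

# Load: initialize both models and optimizers, then restore their states.
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint["modelA_state_dict"])
modelB.load_state_dict(checkpoint["modelB_state_dict"])
optimizerA.load_state_dict(checkpoint["optimizerA_state_dict"])
optimizerB.load_state_dict(checkpoint["optimizerB_state_dict"])
```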
Warm-starting a Model Using Parameters from a Different Model
It is often beneficial and desirable to transfer the weights of the optimizable parameters from one model to another. For example, we can initialize the weights of a model using weights from another, already trained model, giving our current model a warm start. This can help the current model converge much faster, as it benefits from the knowledge of an already trained system.
It might often be the case that the model weights we are trying to transfer are saved as a dictionary (via state_dict) whose keys do not exactly match those of the model we plan to warm-start. Specifically, we might want to load from a partial state_dict that is missing some keys, or from a state_dict that has more keys than the model we wish to load into.
PyTorch allows us to easily do this by simply setting the strict argument to False in the load_state_dict() function call to ignore the non-matching keys.
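A minimal sketch, contrived so that the saved state_dict has more keys than the model being warm-started (all names here are our own):

```python
import torch
import torch.nn as nn

# A "pretrained" model with an extra output layer, saved to disk.
pretrained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
torch.save(pretrained.state_dict(), "pretrained.pth")

# A new model that shares only the first layer with the pretrained one.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU())

# strict=False loads the matching keys and silently ignores the rest.
result = model.load_state_dict(torch.load("pretrained.pth"), strict=False)
print(result.unexpected_keys)  # keys in the file with no match in the model
```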
Saving & Loading Model Across Devices
Deep neural networks can be trained using either the CPU or the GPU as the hardware device. In this section, we will learn how to save and load PyTorch models across these devices.
Save on GPU, Load on CPU
To load a model on the CPU when it was trained on a GPU, we must pass torch.device('cpu') to the map_location argument of the torch.load() function.
This way, the storages underlying the tensors are dynamically remapped to the CPU device with the map_location argument.
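A minimal sketch, assuming a hypothetical TheModelClass and a PATH pointing to a state_dict saved from a GPU:

```python
device = torch.device("cpu")
model = TheModelClass()  # hypothetical class - substitute your own model
# map_location remaps the GPU-saved tensor storages onto the CPU at load time.
model.load_state_dict(torch.load(PATH, map_location=device))
```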
Save on GPU, Load on GPU
To load a model on a GPU device that was also trained and saved on a GPU, we need to convert the initialized model to a CUDA-optimized model using model.to(torch.device('cuda')).
Also, to use the loaded model on the GPU for inference or training, we need to ensure that all model inputs are on the GPU device as well. For this, we must move every input to the GPU using my_tensor = my_tensor.to(torch.device('cuda')).
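A minimal sketch under the same assumptions (hypothetical TheModelClass and PATH):

```python
device = torch.device("cuda")
model = TheModelClass()  # hypothetical class - substitute your own model
model.load_state_dict(torch.load(PATH))
model.to(device)  # convert the model's parameter tensors to CUDA tensors

# Every input tensor must also be on the GPU before the forward pass:
# my_tensor = my_tensor.to(torch.device("cuda"))
```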
Save on CPU, Load on GPU
To load a model onto a GPU when it was trained and saved on the CPU, pass the map_location argument of the torch.load() function as cuda:device_id.
This will load the model to the specified GPU device. After this, we need to call model.to(torch.device('cuda')) to convert the model's parameter tensors to CUDA tensors.
And like before, we need to be sure to use the .to(torch.device('cuda')) function call on all model inputs to feed the data to the CUDA-optimized model.
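Again, a minimal sketch with the same hypothetical names:

```python
device = torch.device("cuda")
model = TheModelClass()  # hypothetical class - substitute your own model
# Load the CPU-saved weights directly onto GPU 0.
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))
model.to(device)  # make sure all parameter tensors are CUDA tensors
```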
Example
We will now walk through a full-fledged example and demonstrate in code how PyTorch models could be saved and loaded to resume the training process later on.
We will be using the Fashion-MNIST dataset, which is easily accessible in PyTorch through the torchvision module. The dataset contains images from 10 classes, and we will build a simple neural-network-based classifier to assign each image in the dataset to one of the 10 categories.
Importing Libraries
Let us first import all the necessary libraries:
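One possible set of imports covering everything used in the sketches below:

```python
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
```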
Creating Helper Functions
We will now create two helper functions to save and load the best model checkpoint. Here we use the serialized dictionary method to save a general checkpoint discussed above.
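Below is one possible implementation; the function names, the checkpoints directory, and the best_model.tar file name are our own choices:

```python
def save_checkpoint(state, save_dir="checkpoints", filename="best_model.tar"):
    """Serialize a checkpoint dictionary to disk."""
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, filename))


def load_checkpoint(path, model, optimizer):
    """Restore model and optimizer states; return the saved epoch and loss."""
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"], checkpoint["valid_loss"]
```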
Importing Data and Creating a Data Loader
Let us now download the dataset and create data loaders for the training and testing counterpart, like so:
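A sketch using torchvision's built-in FashionMNIST dataset; the data root and batch size are arbitrary choices:

```python
transform = transforms.ToTensor()

train_data = datasets.FashionMNIST(root="data", train=True,
                                   download=True, transform=transform)
test_data = datasets.FashionMNIST(root="data", train=False,
                                  download=True, transform=transform)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)
```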
Defining and Creating Model
Let us now create a custom model class to define a simple feed-forward neural network.
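One possible definition; the class name, hidden-layer size, and overall architecture are our own choices:

```python
class Classifier(nn.Module):
    """A simple feed-forward network for 28x28 Fashion-MNIST images."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.ReLU(),
            nn.Linear(256, 10),  # 10 output classes
        )

    def forward(self, x):
        return self.net(x)


model = Classifier()
print(model)
```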
We have defined a simple feed-forward neural network for our classifier model. More appropriate models for classifying images do exist, such as convolutional neural networks; however, since our goal is to demonstrate the checkpoint saving and loading in PyTorch, we are moving ahead with a simple model.
Training the Network & Saving the Model
Let us define a function to train the model for several epochs. This function also keeps track of the validation loss as the model trains through the epochs and saves the checkpoint with the best validation loss. Finally, it returns the trained model.
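A sketch of such a training function, reusing the save_checkpoint helper defined earlier; for simplicity, we use the test loader as a stand-in validation set:

```python
def train(model, optimizer, criterion, train_loader, valid_loader,
          start_epoch, n_epochs, best_valid_loss=float("inf")):
    """Train for n_epochs, checkpointing whenever validation loss improves."""
    for epoch in range(start_epoch, start_epoch + n_epochs):
        model.train()
        for images, labels in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

        # Evaluate on the validation set after every epoch.
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for images, labels in valid_loader:
                valid_loss += criterion(model(images), labels).item()
        valid_loss /= len(valid_loader)
        print(f"epoch {epoch}: validation loss {valid_loss:.4f}")

        # Save a general checkpoint whenever the validation loss improves.
        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            save_checkpoint({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "valid_loss": valid_loss,
            })
    return model
```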
We will first train the model for three epochs and save the best checkpoint.
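For example (the loss function, optimizer, and learning rate are our own choices):

```python
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

trained_model = train(model, optimizer, criterion,
                      train_loader, test_loader,
                      start_epoch=0, n_epochs=3)
```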
To ensure our best model was saved in the required directory, we can ls the files in the directory like so:
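For instance, using os.listdir (in a notebook, a shell command like !ls checkpoints would work just as well):

```python
print(os.listdir("checkpoints"))  # should show best_model.tar
```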
We will now use this saved checkpoint to initialize our model and continue the training for four more epochs. Like earlier, the best model will again be saved in the same directory this time.
As a disclaimer, it would make more sense to save both the current checkpoint and the best checkpoint, resume training from the current checkpoint, and keep saving the best checkpoint all along. For brevity, we have only demonstrated saving the best model and using it to resume training. The code can easily be modified to save both checkpoints in different directory structures and use the current one to resume the training process.
Loading the Model
Let us now load the saved model back to resume training. For that, we will instantiate the model and optimizer instances like so:
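We re-create fresh model and optimizer instances with the same configuration as before:

```python
model = Classifier()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
print(model)
```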
Loading the saved model:
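Using the load_checkpoint helper defined earlier:

```python
start_epoch, best_valid_loss = load_checkpoint(
    "checkpoints/best_model.tar", model, optimizer)
print(f"resuming from epoch {start_epoch} "
      f"with best validation loss {best_valid_loss:.4f}")
```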
We have now initialized our instances using the saved checkpoint and can use them to train the model further or to run inference.
Continue Training
Let us now train the model for four more epochs. Then, as already noted, the best model will again be saved in the same folder, and we can load it back later for inference.
Below, we demonstrate how the continued training works.
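We resume from the epoch stored in the checkpoint and carry over the best validation loss recorded so far:

```python
trained_model = train(model, optimizer, criterion,
                      train_loader, test_loader,
                      start_epoch=start_epoch + 1, n_epochs=4,
                      best_valid_loss=best_valid_loss)
```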
That is all. Our best model is again saved in the same folder structure as earlier, and we also have the trained_model instance, which can be used for evaluation.
Conclusion
In this article, we learned about the mechanics of saving and loading trained models in PyTorch. In particular,
- We first looked at the use cases for which model saving in PyTorch becomes necessary, and then we explored different ways to save trained models.
- We learned about saving models using state_dict, as pickled model objects, and in the TorchScript format.
- We also learned about checkpointing models in PyTorch and saving multiple models in a single file, and we explored saving and loading PyTorch-trained models across devices.
- Finally, we worked through a full code example demonstrating a broad pipeline of how model saving and loading in PyTorch works.