Imagine that one day you have an amazing idea for your machine learning project. You write down all the details on a piece of paper: the model architecture, the optimizer, the dataset. Now you just have to code it up and do some hyperparameter tuning to put it to use.
So, you fire up your machine and start coding. But suddenly it hits you: you need to go through the hard work of creating batches out of the data, writing loops to iterate over batches and epochs, debugging any issues that may arise while doing so, repeating the same for the validation set, and the list goes on. It turns into a headache before you have even started.
But not anymore. PyTorch Lightning is here to save your day. Not only does it automatically do the hard work for you but it also structures your code to make it more scalable. It comes fully packed with awesome features that will enhance your machine learning experience. Beginners should definitely give it a go.
Throughout this article we will learn how Lightning can be used along with PyTorch to make development easy and reproducible.
Roadmap
With this post, I aim to help people get to know PyTorch Lightning. From now on I will be referring to PyTorch Lightning as Lightning.
I will begin with a brief introduction to the new library and its underlying principles so that you can build research-friendly neural network models from scratch.
This tutorial assumes that you have prior knowledge of how a neural network works. It also assumes you are familiar with the PyTorch framework. Even if you are not familiar, you will be alright. For PyTorch users, this tutorial may serve as a medium to encourage them to include Lightning in their PyTorch code.
Let us start with some basic introduction.
What is PyTorch?
Based on the Torch library, PyTorch is an open-source machine learning library. PyTorch is imperative, which means computations run immediately, and the user need not wait to write the full code before checking if it works or not. We can run a part of the code and inspect it in real time. The library is Python-based and built to provide flexibility as a deep learning development platform.
PyTorch is extremely "pythonic" in nature. It is basically a NumPy substitute that can also use the computational power of GPUs.
PyTorch supports dynamic computational graphs, which allow us to change the network on the fly.
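To make the "imperative" and "dynamic graph" points concrete, here is a minimal sketch. The tensor values and the function name `dynamic_forward` are made up for illustration; the point is only that each line runs as soon as it is executed and that ordinary Python control flow decides which operations get recorded.

```python
import torch

# Each line runs immediately, so intermediate values can be inspected right away.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2
print(y)  # available without building or compiling a graph first


def dynamic_forward(t):
    # Plain Python control flow picks the operations recorded for this call:
    # the graph is built on the fly.
    if t.sum() > 0:
        return (t ** 2).sum()
    return t.abs().sum()


out = dynamic_forward(x)
out.backward()   # gradients follow whichever branch actually ran
grad = x.grad    # the derivative of sum(x^2) is 2x
```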
The Catch
PyTorch is an excellent framework, great for researchers. But after a certain point, it involves more engineering than researching.
As I mentioned in the introduction, the hard work starts taking over the research work. The focus shifts from training and tuning the model to correctly implementing the following features:
- Re-coding a training loop
- Multi-cluster training
- 16-bit precision
- Early-stopping
- Model loading/saving
- etc…
Even though these may be simple to implement, we still end up losing precious time, and any mistake made while coding them up costs even more time in debugging.
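For reference, this is roughly what the hand-written part looks like in plain PyTorch. It is a minimal sketch on synthetic data (a made-up regression problem, arbitrary learning rate and epoch count), not the classifier built later in this article, but the epoch loop, the batch loop and the optimizer calls are the boilerplate in question.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic regression data (illustrative only): y = 2x + 1 plus a little noise.
inputs = torch.linspace(-1, 1, 200).unsqueeze(1)
targets = 2 * inputs + 1 + 0.01 * torch.randn_like(inputs)
loader = DataLoader(TensorDataset(inputs, targets), batch_size=20, shuffle=True)

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

epoch_losses = []
for epoch in range(20):                       # epoch loop
    running = 0.0
    for x, y in loader:                       # batch loop
        optimizer.zero_grad()                 # reset gradients
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()                       # backpropagate
        optimizer.step()                      # update weights
        running += loss.item() * x.size(0)
    epoch_losses.append(running / len(loader.dataset))
```

A validation loop, checkpointing and early stopping would each add another block of similar code on top of this.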
Consider an example. We are training a model, and we want it to stop after 100 epochs and save the trained model into a .pth file. But we made a mistake in writing the model-saving code. The thing about Python is that it does not show an error until it runs into one. So, after 10 hours of training, we run into an error, and our model is not saved. Just like that, the 10 hours go down the drain. How frustrating would that be?
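One cheap way to guard against this in plain PyTorch is to exercise the save and load path once before the long run starts. The sketch below assumes a tiny stand-in model and a hypothetical file name `model.pth` inside a temporary directory; it only shows the idea of a quick round trip.

```python
import os
import tempfile

import torch

model = torch.nn.Linear(4, 2)   # stand-in for the real model

# Run the save/load code BEFORE the long training job, so a typo in it
# fails within seconds instead of after hours of training.
with tempfile.TemporaryDirectory() as tmp_dir:
    ckpt_path = os.path.join(tmp_dir, "model.pth")   # hypothetical file name
    torch.save(model.state_dict(), ckpt_path)

    restored = torch.nn.Linear(4, 2)
    restored.load_state_dict(torch.load(ckpt_path))

# Every tensor in the restored model should equal the one that was saved.
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
```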
Enter Lightning
Lightning is a very lightweight wrapper on PyTorch. This means you don’t have to learn a new library. It defers the core training and validation logic to you and automates the rest. It guarantees tested and correct code with the best modern practices for the automated parts.
So we can actually save those 10 hours by carefully organizing our code in Lightning modules.
As the name suggests, Lightning is closely related to PyTorch: not only do they share their roots at Facebook, but Lightning is also a wrapper for PyTorch itself. In fact, the core foundation of PyTorch Lightning is built upon PyTorch.
In its true sense, Lightning is a structuring tool for your PyTorch code. You just have to provide the bare minimum details (e.g., number of epochs, optimizer, etc.). The rest will be automated by Lightning.
By using Lightning, you make sure that all the tricky pieces of code work for you and you can focus on the real research:
- Hyperparameter tuning
- Finding the best model for a problem
- Visualizing results
Lightning ensures that when your network becomes complex, your code doesn't.
It ensures that you focus on the real deal and do not worry about how to run your model on multiple GPUs or how to speed up the code. Lightning will handle that for you.
But what does this mean for you? It means that this framework is designed to be extremely extensible while making state-of-the-art AI research techniques (like multi-GPU training) trivial.
Quick MNIST Classifier on Google Colab
I will show you exactly how to build an MNIST classifier using Lightning. I will walk you through a very small network that reaches 99.4% accuracy on the MNIST validation set using fewer than 8k trainable parameters. I tried re-implementing the code using PyTorch Lightning and added my own intuitions and explanations.
We shall do this as quickly as possible so that we can move on to even more interesting details of Lightning.
The Main Aspects of a Lightning Model
The basic and essential chunks of a neural network in Lightning are the following:
- Model architecture — Restructuring
- Data — Restructuring
- Forward pass — Restructuring
- Optimizer — Restructuring
- Training Step — Restructuring
- Training and Validation Loops (Lightning Trainer) — Abstraction
We can clearly see that they fall into two categories: Restructuring and Abstraction.
Restructuring
Restructuring refers to keeping code in its respective place in the LightningModule. The code is simply arranged into predefined methods of the LightningModule, which this article calls callbacks (the Lightning documentation calls them hooks and uses the name Callback for a separate class). Their names have a special meaning to Lightning because they tell it what each function is for.
Note that the PyTorch code itself does not change during the restructuring.
Abstraction
The boilerplate code is abstracted by the Lightning trainer. It automates most of the code for us.
Now there is no need to write separate code for saving your model or iterating over batches. It is now abstracted into the Trainer.
What does it contain?
Lightning provides us with the following methods of its class pl.LightningModule that help in structuring the code. They are referred to here as callbacks:
- forward: This is the good old forward method that we have in nn.Module in PyTorch. It remains exactly the same in Lightning.
- training_step: This contains the commands that are executed for each batch during training. We usually call for a forward pass in here for the training data. Its sister functions are test_step and validation_step.
- training_epoch_end: As the name suggests, this callback determines what will be done with the results (the outcome of a forward pass) at the end of an epoch. Its sister functions are test_epoch_end and validation_epoch_end.
- train_dataloader: This method allows us to set up the dataset for training and returns a DataLoader object from the torch.utils.data module. Its sister functions are test_dataloader and val_dataloader.
- configure_optimizers: It sets up the optimizers that we might want to use, such as Adam, SGD, etc. We can even return 2 optimizers (in case of a GAN).
- training_end: It contains the piece of code that will be executed when training ends.
- and many more such callbacks
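To give a feel for what an epoch-end method typically does, here is a small sketch written as a free function in plain PyTorch. It assumes the per-batch results arrive as a list of dictionaries with a "loss" entry, which mirrors what the training step in this article returns; the function name `aggregate_epoch` and the numbers are made up, and the exact arguments Lightning passes depend on the release you use.

```python
import torch


def aggregate_epoch(outputs):
    """Average the per-batch losses collected over one epoch.

    `outputs` is assumed to be a list of dicts, one per batch, each holding
    a scalar tensor under the key "loss".
    """
    avg_loss = torch.stack([out["loss"] for out in outputs]).mean()
    return {"avg_loss": avg_loss}


# Stand-in for three batches' worth of step outputs (made-up numbers).
fake_outputs = [{"loss": torch.tensor(v)} for v in (0.9, 0.6, 0.3)]
summary = aggregate_epoch(fake_outputs)
```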
Coding an MNIST Classifier
Now let us dive right into coding so that we can get some hands-on experience with Lightning.
Installing Lightning
Run the following to install Lightning on Google Colab
!pip install pytorch_lightning
You will have to restart the runtime for the changes to take effect.
Do not forget to select the GPU. Go to Edit->Notebook Settings->Hardware Accelerator and select GPU in the Google Colab notebook.
Import Libraries
import os
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
1. The Model
We will be defining our own class called smallAndSmartModel, and it will inherit from pl.LightningModule.
Let us start building the model:
class smallAndSmartModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer1 = torch.nn.Sequential(
            torch.nn.Conv2d(1,28,kernel_size=5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.layer2 = torch.nn.Sequential(
            torch.nn.Conv2d(28,10,kernel_size=2),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2))
        self.dropout1=torch.nn.Dropout(0.25)
        self.fc1=torch.nn.Linear(250,18)
        self.dropout2=torch.nn.Dropout(0.08)
        self.fc2=torch.nn.Linear(18,10)
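If you are wondering where the 250 in the first linear layer comes from, a quick sanity check is to push a dummy batch through the convolutional layers and look at the shapes. The sketch below rebuilds the same layers outside the class (without dropout, which has no parameters) and also counts the trainable parameters; the variable names are mine.

```python
import torch

layer1 = torch.nn.Sequential(
    torch.nn.Conv2d(1, 28, kernel_size=5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2),
)
layer2 = torch.nn.Sequential(
    torch.nn.Conv2d(28, 10, kernel_size=2),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2),
)
fc1 = torch.nn.Linear(250, 18)
fc2 = torch.nn.Linear(18, 10)

# A dummy batch of two MNIST-sized images: (batch, channels, height, width).
dummy = torch.zeros(2, 1, 28, 28)
features = layer2(layer1(dummy))
flat_size = features.view(features.size(0), -1).size(1)

# Total number of trainable parameters across all four blocks.
n_params = sum(
    p.numel() for module in (layer1, layer2, fc1, fc2) for p in module.parameters()
)
print(features.shape, flat_size, n_params)
```

The spatial size goes 28 -> 24 -> 12 -> 11 -> 5, so the flattened feature vector has 10 x 5 x 5 = 250 entries, and the parameter count stays below the 8k mentioned earlier.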
2. Data Loading
class smallAndSmartModel(pl.LightningModule):
    #This contains the manipulation on data that needs to be done only once, such as downloading it
    def prepare_data(self):
        MNIST(os.getcwd(), train=True, download=True)
        MNIST(os.getcwd(), train=False, download=True)

    def train_dataloader(self):
        #This is an essential function. Needs to be included in the code
        #download is set to False here as the data is already downloaded in prepare_data
        mnist_train=MNIST(os.getcwd(), train=True, download=False, transform=transforms.ToTensor())
        #Dividing into validation and training set
        self.train_set, self.val_set=random_split(mnist_train,[55000,5000])
        return DataLoader(self.train_set,batch_size=128)

    def val_dataloader(self):
        # OPTIONAL; uses the validation split created in train_dataloader
        return DataLoader(self.val_set, batch_size=128)

    def test_dataloader(self):
        # OPTIONAL
        return DataLoader(MNIST(os.getcwd(), train=False, download=False, transform=transforms.ToTensor()), batch_size=128)
The train_dataloader, test_dataloader and val_dataloader are reserved functions in pl.LightningModule. We use them as wrappers for loading our data.
The code has to live in these functions because their names have a special meaning in Lightning, just as forward does in nn.Module.
Each of these is responsible for returning the appropriate data split. Lightning structures it in a way so that it is very clear how the data is being manipulated. If you read someone else's code that is not structured like this (as is the case in many GitHub repositories), it can be hard to figure out how they manipulated their data.
Lightning even allows multiple data loaders for testing or validating.
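The splitting and batching used above can be tried on their own, without downloading MNIST. This sketch uses a fake dataset of 600 blank "images" (my stand-in, not real data) to show how random_split and DataLoader fit together.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

# Stand-in dataset: 600 fake "images" with integer labels (not real MNIST).
images = torch.zeros(600, 1, 28, 28)
labels = torch.randint(0, 10, (600,))
full_set = TensorDataset(images, labels)

# The split sizes must add up to the dataset length,
# just as 55000 + 5000 adds up to the 60000 MNIST training images.
train_set, val_set = random_split(full_set, [550, 50])

train_loader = DataLoader(train_set, batch_size=128, shuffle=True)
val_loader = DataLoader(val_set, batch_size=128)

# Pull one batch to see what a training step would receive.
first_x, first_y = next(iter(train_loader))
```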
3. Forward Pass
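class smallAndSmartModel(pl.LightningModule):
    def forward(self,x):
        x=self.layer1(x)
        x=self.layer2(x)
        x=self.dropout1(x)
        #flatten to (batch, 250) before the fully connected layers
        x=torch.relu(self.fc1(x.view(x.size(0), -1)))
        x=F.leaky_relu(self.dropout2(x))
        #log-probabilities, as expected by the negative log-likelihood loss
        return F.log_softmax(self.fc2(x), dim=1)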
This is the forward pass, where the calculation takes place and we generate the values for the output layer from the input data.
Users of PyTorch may notice that there is no change in its implementation.
4. Optimizer
class smallAndSmartModel(pl.LightningModule):
    def configure_optimizers(self):
        # Essential function
        #we are using Adam optimizer for our model
        return torch.optim.Adam(self.parameters())
This required function returns the kind of optimizer we require. Interestingly, Lightning provides us with the wrapper configure_optimizers, which allows us to even return multiple optimizers with ease (for example in GANs).
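As a rough illustration of the multiple-optimizer case, the sketch below returns one optimizer per sub-network. `TinyGAN` is a hypothetical two-layer toy written as a plain nn.Module so that it runs without Lightning, and the learning rates are arbitrary; how Lightning alternates between the returned optimizers depends on the release you use.

```python
import torch


class TinyGAN(torch.nn.Module):
    """Hypothetical two-network model used only to illustrate the idea."""

    def __init__(self):
        super().__init__()
        self.generator = torch.nn.Linear(8, 16)
        self.discriminator = torch.nn.Linear(16, 1)

    def configure_optimizers(self):
        # One optimizer per sub-network, each seeing only its own parameters
        # (the learning rates are arbitrary example values).
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=1e-4)
        return opt_g, opt_d


opt_g, opt_d = TinyGAN().configure_optimizers()
```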
5. Training Step (The interesting part)
class smallAndSmartModel(pl.LightningModule):
    def training_step(self,batch,batch_idx):
        #extracting input and output from the batch
        x,labels=batch
        #doing a forward pass
        pred=self(x)
        #calculating the loss (nll_loss expects log-probabilities from forward)
        loss = F.nll_loss(pred, labels)
        #logs
        logs={"train_loss": loss}
        output={
            #REQUIRED: It is required for us to return "loss"
            "loss": loss,
            #optional for logging purposes
            "log": logs
        }
        return output
This step is called for every batch in our dataset. Some key operations that occur in this function are:
- The actual forward pass is made on the input to get the outcome pred from the model
- The loss is calculated on the batch
- A logs dictionary is prepared
- An output dictionary is returned
It is essential for training_step to return a dictionary containing loss. Any other data returned is optional.
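One detail worth checking in a training step like this: the negative log-likelihood loss in PyTorch expects log-probabilities, so the model output has to come from log_softmax rather than softmax. The sketch below uses made-up scores and labels to show that this pairing gives the same value as cross_entropy applied to the raw scores.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # fake raw scores: 4 samples, 10 classes
labels = torch.tensor([3, 0, 7, 1])    # fake targets

# nll_loss expects LOG-probabilities, so pair it with log_softmax.
nll = F.nll_loss(F.log_softmax(logits, dim=1), labels)

# cross_entropy applies log_softmax internally and takes raw scores directly.
ce = F.cross_entropy(logits, labels)
```

Either pairing works; what matters is not feeding plain softmax probabilities into nll_loss.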
6. The Lightning Trainer (Where Magic Happens)
Obviously, there is no magic. But when I tell you what Lightning Trainer is capable of, you won’t refrain from claiming that indeed, it is charming and exquisite.
#abstracts the training, val and test loops
#using one GPU given to us by Google Colab for a maximum of 100 epochs
#(gpus=1 targets older Lightning releases; Lightning 2.x uses accelerator="gpu", devices=1)
myTrainer=pl.Trainer(gpus=1,max_epochs=100)
model=smallAndSmartModel()
myTrainer.fit(model)
The Trainer is the heart of PyTorch Lightning. This is where all the abstractions take place. It abstracts the most obvious pieces of code such as:
- The batch iteration
- The epoch iteration
- Calling of
optimizer.step()
- The validation loop
Now you don’t have to worry about engineering these steps. The Trainer does that for you. You just have to make sure that your code is well structured as explained in the above sections.
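To demystify what gets abstracted, here is a toy trainer in plain PyTorch. It is emphatically not Lightning's implementation: `ToyTrainer` and `ToyModel` are hypothetical classes of mine, the data is synthetic, and the model is an ordinary nn.Module that merely borrows the method names used in this article. It only shows how a trainer can drive those methods.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(torch.nn.Module):
    """Plain nn.Module that mimics the method names used in this article."""

    def __init__(self):
        super().__init__()
        self.net = torch.nn.Linear(1, 1)
        x = torch.linspace(-1, 1, 64).unsqueeze(1)
        self.data = TensorDataset(x, 3 * x)   # made-up target: y = 3x

    def forward(self, x):
        return self.net(x)

    def train_dataloader(self):
        return DataLoader(self.data, batch_size=16, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.data, batch_size=16)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {"loss": torch.nn.functional.mse_loss(self(x), y)}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        return {"val_loss": torch.nn.functional.mse_loss(self(x), y)}


class ToyTrainer:
    """Illustrative stand-in for a trainer; NOT Lightning's implementation."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.val_history = []

    def fit(self, model):
        optimizer = model.configure_optimizers()
        for _ in range(self.max_epochs):                 # epoch iteration
            model.train()
            for idx, batch in enumerate(model.train_dataloader()):  # batch iteration
                optimizer.zero_grad()
                model.training_step(batch, idx)["loss"].backward()
                optimizer.step()                         # the optimizer.step() call
            model.eval()
            with torch.no_grad():                        # validation loop
                losses = [
                    model.validation_step(b, i)["val_loss"]
                    for i, b in enumerate(model.val_dataloader())
                ]
            self.val_history.append(torch.stack(losses).mean().item())


torch.manual_seed(0)
trainer = ToyTrainer(max_epochs=30)
trainer.fit(ToyModel())
```

The model only says what happens for one batch; the trainer owns the loops. That division of labour is the idea behind the real Trainer as well.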
Lightning Trainer Flags
The trainer provides some very helpful flags. We can assign values to these flags to configure our classifier’s behavior.
- gpus: Number of GPUs you want to train on
- max_epochs: Stop training once this number of epochs is reached
- min_epochs: Force training for at least this many epochs
- weights_save_path: Directory where weights are saved, if specified
- precision: Full (32-bit) or half (16-bit) precision
- and many more
Perks of Lightning Trainer
By using the Trainer, you automatically get the following tools and features:
- Training and validation loop
- Tensorboard logging
- Early-stopping
- Model checkpointing
- The ability to resume training from wherever you left
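To show the kind of logic you no longer have to write yourself, here is a bare-bones early-stopping rule in pure Python. `EarlyStopper` is a hypothetical helper of mine, not Lightning's own early-stopping feature, and the validation losses are made up.

```python
class EarlyStopper:
    """Hypothetical helper: stop when the monitored loss stops improving."""

    def __init__(self, patience=2, min_delta=0.0):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # smallest change that counts as improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses: improvement stalls after the third epoch.
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.64]
stopper = EarlyStopper(patience=2)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.should_stop(loss):
        stopped_at = epoch
        break
```

With a patience of 2, training stops at the second epoch in a row without improvement, so the last value in the list is never reached.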
Why should I start using PyTorch Lightning?
That is the question you should be asking me after I have told you so much about PyTorch Lightning. I will answer it by letting you in on my love for Lightning.
1. Peace of Mind (Structured Code)
When I look at how the code is structured in Lightning, it feels almost natural and intuitive to put it there. The structuring ensures that I have a step-by-step strategy of developing my classifier from scratch. It is as if it makes me more confident in developing my models.
2. Simplicity
The steps to build a machine learning solution are now very simple and intuitive.
Now, to come up with a solution using Lightning, I know that I need to proceed by preparing the data, adding optimizers, adding the training step, and so on. This helps me move along with the flow of ideas in my mind.
3. Grouping the relevant together
The best thing about Lightning is that each process is separated from the others in the LightningModule. That is the benefit of structuring.
training_step contains information about the training step and not about the validation step or the optimizer. It makes things clearer for me.
4. No True Change in Code required
Since Lightning is a wrapper for PyTorch, I did not have to learn a new language. Also, if I want to make very complex training steps I can easily do that without compromising on the flexibility of PyTorch.
Those who are familiar with PyTorch will find the transition to be extremely smooth.
5. The Lightning Trainer — Automation
The Trainer just wins it all. It automates most of the complex tasks for me.
In the case of GPUs, I don't have to worry about moving my tensors with tensor.to(device=cuda). It automatically figures out the details. I just have to set a few flags. With this, I can even enable 16-bit precision, auto-cluster saving, auto-learning-rate-finder, Tensorboard visualization, etc.
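For comparison, this is the manual device handling that plain PyTorch code usually carries around. It is a minimal sketch with a stand-in model and random data, written so that it also runs on a machine without a GPU.

```python
import torch

# Pick the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(4, 2).to(device)   # move the parameters
batch = torch.randn(8, 4).to(device)       # move every batch as well

output = model(batch)                      # both live on the same device now
```

Forgetting either of the two .to(device) calls leads to a device mismatch error on a GPU machine, which is exactly the kind of detail the Trainer takes off your hands.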
By using the Trainer, I’m not only getting some very neat algorithms but I am also getting the guarantee that they will work correctly. Now that’s one less thing for me to worry about. And I can focus on my real research.
My personal favorite is Tensorboard logging and resuming training from where I left it.
Whom does Lightning cater to?
Lightning is best for scholars and researchers who are working on developing the best strategies to tackle a problem. Lightning takes the unnecessary engineering off their hands and provides them with a clean environment to perform relevant research.
I also believe that early PyTorch users should start using Lightning so that their thinking process becomes structured and more intuitive. They might also find it amazing to have so many perks at their disposal, ready to be used.
Congratulations
Now that you are acquainted with PyTorch Lightning, I hope you will start using Lightning (especially if you are a researcher) and fall in love with its amazing features.
That is all from me. If you liked my little introduction to Lightning, do share your feedback.
Keep learning and have fun!!