This is a gentle introduction to federated learning, a technique that makes machine learning more private and secure by training on decentralized data. We will also work through a real-life example of federated learning.

Introduction
As the field of machine learning grows, so do the data privacy concerns that come with it. This is especially true when we train models on portable devices using sensitive data, such as someone’s daily routine or their heart activity for the week.
How do we train and improve these on-device machine learning models without sharing personally identifiable data? Federated Learning tries to solve exactly this problem.
Why Federated Learning?
Or: the problems with centralized learning
Normally, when we train a machine learning model, we need access to the data, which we can view freely. This way of training works just fine as long as the privacy of the data is not a concern.
But let’s say we were working on cancer diagnosis. We would need a lot of data from health clinics, and if they are not willing to share it because of privacy concerns, we are stuck.
This is exactly the problem with centralized learning: we can’t work with sensitive data that its owners won’t share.
Federated learning is a training technique that allows devices to collaboratively learn a single shared model. The shared model is first trained on a server with some initial data to kickstart the training process. Each device then downloads the model and improves it using the data present on the device (federated data). Only the resulting model updates, not the data itself, are sent back and combined into the shared model.
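The paragraph above does not say how the device updates are combined. One common choice is to average the weights the devices send back, weighting each device by how much data it holds (the idea behind the FedAvg algorithm). Below is a minimal sketch of that idea in plain Python, assuming a one-parameter linear model; local_update, federated_average and the device data are hypothetical names made up for illustration, not part of any library.

```python
# Toy federated averaging: plain Python lists stand in for model weights.
# local_update, federated_average and the device data are all hypothetical.

def local_update(weights, data, lr=0.1):
    """Run one pass of SGD for the 1-D model y = w * x on a device's own data."""
    w = weights[0]
    for x, y in data:
        grad = 2 * (w * x - y) * x  # derivative of (w * x - y) ** 2 with respect to w
        w -= lr * grad
    return [w]


def federated_average(updates, sizes):
    """Average each weight across devices, weighted by each device's example count."""
    total = sum(sizes)
    return [
        sum(update[i] * n for update, n in zip(updates, sizes)) / total
        for i in range(len(updates[0]))
    ]


# Hypothetical on-device datasets, both following y = 3x; they never leave the device.
device_data = {
    "phone_a": [(1.0, 3.0), (0.5, 1.5)],
    "phone_b": [(0.8, 2.4), (1.0, 3.0), (0.6, 1.8)],
}

global_weights = [0.0]  # the shared model, initialised on the server
for round_num in range(20):
    updates, sizes = [], []
    for data in device_data.values():
        updates.append(local_update(global_weights, data))  # training happens on the device
        sizes.append(len(data))
    global_weights = federated_average(updates, sizes)  # only weights travel to the server

print(global_weights)
```

Only the weights cross the "network" in this loop; the (x, y) pairs stay in device_data. The MNIST example later in this article takes a simpler route and trains a single model sequentially at each school instead of averaging.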
In fact, most data is born decentralized, if I may put it this way: it is generated on users’ devices. So using federated learning is a no-brainer in such cases. There is no hassle of aggregating the data on a single server, and we get added data security as well.
Federated Learning has many advantages over centralized learning:
- It improves privacy, since the raw data remains on the user’s device the whole time and the model owner never sees it. So technically, we are learning from unseen data. How cool is that?
- Because the model lives on the device, it can make predictions locally, which results in a faster experience for the end user.
- Since training is decentralized and privacy is better protected, we can collect and train with data at a very large scale, which can result in more intelligent models.
- We need less computation on the server as the models are trained remotely across thousands of devices.
Who is using it now?

Many companies, including Google and Apple, are using the power of Federated Learning.
Google uses it in its Gboard keyboard. In their case, a recurrent neural network (RNN) is trained on decentralized on-device datasets to predict the next word on smartphone keyboards. The use of federated learning led to a substantial increase in Gboard’s next-word prediction accuracy.
Apple uses it to improve Siri and its QuickType keyboard, following the same principle discussed above. If you own an iPhone, you may have noticed that the voice assistant “wakes up” when you say “Hey Siri,” but not when the same phrase comes from your friends or family.
If you wonder why this is so important, you have to watch Mark Rober’s April Fools’ Day prank.
Now that we have the what and why out of the way, we can start to get our hands dirty. But wait a minute! Neither Python nor PyTorch supports federated learning out of the box. PySyft to the rescue! In simple terms, PySyft is a wrapper around PyTorch that adds extra functionality to it. I will discuss how to use PySyft in the next section.
Check out their GitHub repo here 😉
Basic API details about PySyft
In this section, I will cover some basic APIs you should know when using PySyft. First, we must install the PySyft library on our machine. This can be done using pip install syft. Note that the code in this article uses the PySyft 0.2.x API (sy.TorchHook, sy.VirtualWorker), which later releases of PySyft replaced.
If you face problems installing syft, check out their installation guide here. If that doesn’t help, you can always raise an issue on their GitHub repo. The folks at OpenMined are really active and will be glad to help you solve your issue. 🙂
So the first question you may have is: how in the world do we train a model on data we don’t have access to or cannot see?
We usually perform deep learning on the machine that holds the data, but now we want to perform this kind of computation on some other remote machine. More specifically, we can no longer assume that the data is on our local machine. Thus, instead of using a Torch tensor, we’re now going to work with a pointer to a tensor (the tensor itself is stored at the remote location), with the help of PySyft.
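To build some intuition before looking at PySyft, here is a toy sketch in plain Python of the pointer idea: the remote worker keeps the actual values in a dictionary keyed by object ID, and we only hold a small pointer object that records where the value lives and under which ID. ToyWorker, ToyPointer and send are hypothetical names; PySyft’s real PointerTensor is far more involved.

```python
import itertools

_ids = itertools.count(1)  # simple source of unique object IDs


class ToyWorker:
    """Hypothetical stand-in for a remote machine: it just stores objects by ID."""

    def __init__(self, name):
        self.name = name
        self._objects = {}


class ToyPointer:
    """Records *where* a value lives and under which ID, never the value itself."""

    def __init__(self, location, id_at_location):
        self.location = location
        self.id_at_location = id_at_location

    def get(self):
        # Move semantics: the value is removed from the worker and handed back to us.
        return self.location._objects.pop(self.id_at_location)

    def __repr__(self):
        return f"[ToyPointer -> {self.location.name}:{self.id_at_location}]"


def send(value, worker):
    """Store `value` on `worker` and return a pointer to it."""
    obj_id = next(_ids)
    worker._objects[obj_id] = value
    return ToyPointer(worker, obj_id)


clinic = ToyWorker("clinic")
ptr = send([0, 1, 2, 1, 2], clinic)
print(ptr)              # [ToyPointer -> clinic:1]
print(clinic._objects)  # {1: [0, 1, 2, 1, 2]}
value = ptr.get()
print(value, clinic._objects)  # [0, 1, 2, 1, 2] {}
```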
I recommend going through the following code carefully: it introduces a lot of things that we will use extensively when we write the training code.
I have adapted the following from this official PySyft tutorial. I felt that the tutorial missed some points that I would like to cover, so that you don’t run into the same problems I faced while digesting it :). Of course, feel free to check out the official tutorial; it’s great.
Also, I will only discuss the things that are relevant to us. If you want to know what is happening behind the scenes, you can always check their GitHub repo.
I. Imports and Hook
import torch
import syft as sy
hook = sy.TorchHook(torch) # add extra functionality to PyTorch
We import PyTorch and PySyft, and then hook torch into syft with the TorchHook class. TorchHook does the wrapping by adding to PyTorch all the additional functionality needed for Federated Learning and other Private AI techniques.
According to PySyft’s docs here,
A Hook which overrides methods on PyTorch Tensors. The purpose of this class is to extend torch methods to allow for the moving of tensors from one worker to another and override torch methods to execute commands on one worker that are called on tensors controlled by the local worker.
This class is typically the first thing you will initialize when using PySyft with PyTorch because it is responsible for augmenting PyTorch with PySyft’s added functionality (such as remote execution).
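Under the hood, augmenting a library like this relies on ordinary Python monkey-patching: wrapping existing methods of a class, or attaching new ones, at runtime. The sketch below shows only that general technique on a hypothetical Tensor class with a hypothetical toy_hook function; it is not how TorchHook is actually implemented.

```python
import functools


class Tensor:
    """Hypothetical stand-in for a library class we cannot edit."""

    def __init__(self, data):
        self.data = data

    def add(self, other):
        return Tensor([a + b for a, b in zip(self.data, other.data)])


calls = []  # records every call that went through the hook


def toy_hook(cls, method_names):
    """Monkey-patch `cls`: wrap existing methods and attach a brand-new one."""
    for name in method_names:
        original = getattr(cls, name)

        # default arguments bind the current loop values into each wrapper
        @functools.wraps(original)
        def wrapper(self, *args, _original=original, _name=name, **kwargs):
            calls.append(_name)  # extra behaviour injected by the hook
            return _original(self, *args, **kwargs)

        setattr(cls, name, wrapper)

    def describe(self):  # a method the original class never had
        return f"hooked tensor with {len(self.data)} elements"

    cls.describe = describe


toy_hook(Tensor, ["add"])
c = Tensor([1, 2]).add(Tensor([3, 4]))
print(c.data, calls, c.describe())  # [4, 6] ['add'] hooked tensor with 2 elements
```

Code that already uses Tensor keeps working unchanged, which is the same property that lets the training code later in this article look like ordinary PyTorch.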
II. Creating a Virtual Worker
Let me show you what I mean by a pointer to a tensor. First, let’s create a virtual machine owned by some health clinic, say harmony_clinic. We will use it to simulate a machine at a remote location. Note that syft calls these machines VirtualWorkers.
# create a machine owned by harmony clinic
harmony_clinic = sy.VirtualWorker(hook=hook,id='clinic')
III. Sending the Tensors
Now, harmony_clinic is at a remote location, but it doesn’t have any data we can use. Let’s create some data so that we can send it to harmony_clinic.
# we create a Tensor, maybe this is some gene sequence
dna = torch.tensor([0,1,2,1,2])
# and now I send it, and in turn we get a pointer back that
# points to that Tensor
dna_ptr = dna.send(harmony_clinic)
print(dna_ptr)
(Wrapper)>[PointerTensor | me:19886223167 -> clinic:88496800993]
We see that the PointerTensor points from me (that’s us; PySyft creates the me worker automatically) to harmony_clinic. We also see some random-looking numbers: these are object IDs that PySyft assigns to every object.
Now harmony_clinic has the tensor that we sent. We can use harmony_clinic._objects to see the objects that harmony_clinic currently holds.
print(harmony_clinic._objects)
{88496800993: tensor([0, 1, 2, 1, 2])}
Notice the object ID of the tensor dna: it is the same as the clinic-side ID (88496800993) that dna_ptr points to above.
IV. Getting back the Tensors
In the same way, we can get a tensor back from a remote location by using the .get() method.
# get back dna
dna = dna_ptr.get()
print(dna)
# And as you can see, the clinic no longer has the tensor dna: it has moved back to our machine!
print(harmony_clinic._objects)
tensor([0, 1, 2, 1, 2])
{}
“But that’s not machine learning,” you might be thinking. I agree, but now comes the fun part: we can use pointers to do arithmetic just as we do in PyTorch. You will see what I mean in the following section.
V. Doing Deep Learning with Pointer Tensors
a = torch.tensor([3.14, 6.28]).send(harmony_clinic)
b = torch.tensor([6.14, 3.28]).send(harmony_clinic)
c = a + b
print(c)
(Wrapper)>[PointerTensor | me:62288919884 -> clinic:28157711005]
Something very interesting happened behind the scenes: when we did c = a + b on our machine, a command was sent to the remote machine, which performed that exact calculation, created a new tensor on its side and then sent back a pointer to us, which we now call c.
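A toy sketch of what "a command was sent to the remote machine" means, in plain Python: the pointer’s + operator does no arithmetic itself; it asks the worker holding both operands to compute and store the result, and only the new object’s ID comes back. ToyWorker and ToyPointer are hypothetical classes written for illustration, not PySyft internals.

```python
import itertools


class ToyWorker:
    """Hypothetical remote machine: stores values and executes commands on them."""

    def __init__(self, name):
        self.name = name
        self._objects = {}
        self._ids = itertools.count(1)

    def store(self, value):
        obj_id = next(self._ids)
        self._objects[obj_id] = value
        return obj_id

    def execute(self, op, left_id, right_id):
        # The computation happens here, "on the clinic's machine".
        left, right = self._objects[left_id], self._objects[right_id]
        if op == "add":
            result = [a + b for a, b in zip(left, right)]
        else:
            raise ValueError(f"unsupported op: {op}")
        return self.store(result)  # only the new ID travels back


class ToyPointer:
    def __init__(self, worker, obj_id):
        self.worker, self.obj_id = worker, obj_id

    def __add__(self, other):
        # Instead of adding locally, send an "add" command to the worker.
        new_id = self.worker.execute("add", self.obj_id, other.obj_id)
        return ToyPointer(self.worker, new_id)

    def get(self):
        return self.worker._objects.pop(self.obj_id)


clinic = ToyWorker("clinic")
a = ToyPointer(clinic, clinic.store([3.14, 6.28]))
b = ToyPointer(clinic, clinic.store([6.14, 3.28]))
c = a + b  # c is just another pointer
print(c.obj_id, sorted(clinic._objects))  # 3 [1, 2, 3]
```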
The amazing thing is that this API has been extended to nearly all PyTorch operations, including backpropagation. Hurray!!
This means that we can use the same PyTorch code that we usually write when doing Machine Learning.
If you don’t believe me, check out the following.
# we create two tensors and send them to harmony_clinic
train = torch.tensor([2.4, 6.2], requires_grad=True).send(harmony_clinic)
label = torch.tensor([2, 6.]).send(harmony_clinic)
# we apply some function, in this case a rather simple one, just to show the idea, we use L1 loss
loss = (train-label).abs().sum()
# Yes, even .backward() works when working with Pointers
loss.backward()
# now we retrieve the train tensor
train = train.get()
print(train)
# If everything went well, we will see gradients accumulated
# in .grad attribute of train
print(train.grad)
tensor([2.4, 6.2], requires_grad=True)
tensor([1., 1.])
And indeed we do!! As you can see, the API is really quite flexible and capable of performing nearly any operation you would normally perform in Torch on remote data. This lays the groundwork for the next part of this article, i.e. doing federated learning on MNIST data.
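The gradients printed above can be checked by hand. Writing train as x and label as y, the loss is an L1 loss, whose derivative with respect to each element is the sign of the difference:

```latex
\begin{aligned}
L(\mathbf{x}) &= \sum_{i} \lvert x_i - y_i \rvert, &
\frac{\partial L}{\partial x_i} &= \operatorname{sign}(x_i - y_i) \quad (x_i \neq y_i), \\
\mathbf{x} - \mathbf{y} &= (2.4 - 2,\ 6.2 - 6) = (0.4,\ 0.2), &
\nabla_{\mathbf{x}} L &= (1,\ 1).
\end{aligned}
```

Both differences are positive, so each element of train.grad is 1.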
Now that we have gone through an introduction to the API, we can move on to the interesting part: training the model.
Real-life example of Federated Learning
I will use the official PyTorch MNIST example as a reference; you can look it up here.
Let us imagine a scenario where we want to build a handwritten digit classifier for schools to use. Sadly, we don’t have the data to train a model (let’s assume the MNIST data doesn’t even exist).
Let’s say there are two schools near where we live, “Westside School” and “Grapevine High”, but neither school has enough training data to train its own model. Their combined data, however, could train an effective model that both schools can use. The catch is that both schools are worried about the privacy of their data.
So we come to the rescue. We propose training a model in a federated manner, which means we never have to look at the data at either school and can still use the power of the combined dataset to train a good model. Both schools liked our idea and allowed us to use their data.
So in our example, we will use MNIST data distributed between these two imaginary schools to simulate a real-life federated learning scenario.
Imports and Model Architecture
We use all the basic imports that we normally need for any deep learning problem in PyTorch.
The only extra things we need are PySyft and hooking it onto PyTorch, which adds all the goodness required for federated learning, as we discussed in the API introduction.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import logging
# import PySyft to help us simulate federated learning
import syft as sy
# hook PyTorch to PySyft i.e. add extra functionalities to support Federated Learning
# and other private AI tools
hook = sy.TorchHook(torch)
# we create two imaginary schools
westside_school = sy.VirtualWorker(hook, id="westside")
grapevine_high = sy.VirtualWorker(hook, id="grapevine")
Next, we define hyperparameters such as the learning rate, batch size and test batch size.
# define the args
args = {
'use_cuda' : True,
'batch_size' : 64,
'test_batch_size' : 1000,
'lr' : 0.01,
'log_interval' : 10,
'epochs' : 10
}
# check to use GPU or not
use_cuda = args['use_cuda'] and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
Now we define a very simple CNN.
# create a simple CNN net
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU()
        )
        self.fc = nn.Sequential(
            nn.Linear(in_features=64*12*12, out_features=128),
            nn.ReLU(),
            nn.Linear(in_features=128, out_features=10),
        )
        # note: this dropout layer is defined but never applied in forward()
        self.dropout = nn.Dropout2d(0.25)

    def forward(self, x):
        x = self.conv(x)             # (N, 64, 24, 24) for 28x28 inputs
        x = F.max_pool2d(x, 2)       # (N, 64, 12, 12)
        x = x.view(-1, 64*12*12)     # flatten for the fully connected layers
        x = self.fc(x)
        x = F.log_softmax(x, dim=1)  # log-probabilities, as expected by F.nll_loss
        return x
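Where does 64*12*12 come from? A quick sanity check in plain Python, using the standard output-size formula for a convolution without dilation and assuming 28x28 MNIST inputs (conv2d_out is a hypothetical helper):

```python
def conv2d_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution without dilation: (n + 2p - k) // s + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 28                                # MNIST images are 28x28
size = conv2d_out(size, kernel_size=3)   # first Conv2d  -> 26
size = conv2d_out(size, kernel_size=3)   # second Conv2d -> 24
size = size // 2                         # F.max_pool2d(x, 2) -> 12
flat_features = 64 * size * size         # 64 channels come out of the last conv
print(size, flat_features)               # 12 9216
```

If you change the kernel sizes or add padding, the in_features of the first Linear layer and the view() call must change to match.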
Sending the data to schools
We first load the data and then transform it into a federated dataset using the .federate() method. It does a couple of things for us:
- It splits the dataset into parts, one per worker (two, in our case).
- It also sends each part to its remote worker, in our case the two schools.
We will then use this newly created federated dataset to iterate over remote batches during our training loop.
Note that in real life we won’t send the data to the schools; instead, the schools will already have the data, and we will only hold pointers to it. Here we are just simulating a real-life scenario.
The test dataset stays with us (i.e. the local worker, in PySyft’s terminology), so its code is the same as in the official example.
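The splitting step can be pictured with a toy sketch in plain Python. split_among_workers is a hypothetical helper, and cutting the data into contiguous, roughly equal shards is an assumption made for illustration; the exact strategy .federate() uses may differ, and in the real thing each shard is also sent to its worker.

```python
def split_among_workers(dataset, workers):
    """Toy 'federation': cut the dataset into one contiguous, roughly equal
    shard per worker and record which worker holds which shard."""
    base, extra = divmod(len(dataset), len(workers))
    shards, start = {}, 0
    for i, worker in enumerate(workers):
        end = start + base + (1 if i < extra else 0)  # spread any remainder
        shards[worker] = dataset[start:end]
        start = end
    return shards


examples = list(range(10))  # stand-in for 10 (image, label) pairs
shards = split_among_workers(examples, ["grapevine", "westside"])
print({w: len(s) for w, s in shards.items()})  # {'grapevine': 5, 'westside': 5}
```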
# Now we take the help of PySyft's awesome API to prepare the data and
# distribute it across 2 workers, i.e. the two schools.
# Normally we don't have to distribute data; the data is already there at the site.
# We are doing this just to simulate federated learning.
# The code below looks just like torch code, with only minor changes. This is what's nice about PySyft.
federated_train_loader = sy.FederatedDataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]))
.federate((grapevine_high, westside_school)),
batch_size=args['batch_size'], shuffle=True)
# test data remains with us locally
# this is the normal torch code to load test data from MNIST
# that we are all familiar with
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args['test_batch_size'], shuffle=True)
Training and Validation functions

Since the data now lives at the schools, for each batch we need to send the model to the location of that batch. We use the .send() method we learned above to do this.
Then, we perform all the operations remotely with the same syntax as in local PyTorch. When we’re done, we get back the updated model using the .get() method. Simple, isn’t it?
Note that in the train function below, (data, target) is a pair of PointerTensors.
From a PointerTensor, we can get the worker it points to via the .location attribute, and that is precisely what we use to send the model to the correct location.
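Stripped of PySyft, the control flow of train() is: for each batch, go to wherever that batch lives, take one optimization step there, and bring the updated model back. The toy sketch below reproduces that loop in plain Python with a one-weight linear model; ToySchool and federated_batches are hypothetical, and "sending" the model is simulated by calling a method on the school object.

```python
class ToySchool:
    """Hypothetical data owner: keeps its (x, y) pairs and trains the model it is sent."""

    def __init__(self, name, data):
        self.name, self.data = name, data

    def train_step(self, w, batch, lr):
        # Runs "at the school": one gradient step of mean squared error for y = w * x.
        grad = sum(2 * (w * x - y) * x for x, y in batch) / len(batch)
        return w - lr * grad


def federated_batches(schools, batch_size):
    """Yield (location, batch) pairs, like iterating over a federated loader."""
    for school in schools:
        for i in range(0, len(school.data), batch_size):
            yield school, school.data[i:i + batch_size]


schools = [
    ToySchool("westside", [(x, 2.0 * x) for x in (0.2, 0.4, 0.6, 0.8)]),
    ToySchool("grapevine", [(x, 2.0 * x) for x in (0.1, 0.3, 0.5, 0.7, 0.9)]),
]

w = 0.0  # our model, held locally between batches
for epoch in range(50):
    for location, batch in federated_batches(schools, batch_size=2):
        # "send" the model to location, train there, then "get" it back
        w = location.train_step(w, batch, lr=0.5)
print(round(w, 3))  # 2.0
```

Note that a single model visits the schools one after another rather than being trained in parallel, exactly as in the PySyft loop.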
def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    # iterate over federated data
    for batch_idx, (data, target) in enumerate(train_loader):
        # send the model to the remote location
        model = model.send(data.location)
        # the same torch code that we are used to
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        # this loss is a ptr to the tensor loss
        # at the remote location
        loss = F.nll_loss(output, target)
        # call backward() on the loss ptr,
        # that will send the command to call
        # backward on the actual loss tensor
        # present on the remote machine
        loss.backward()
        optimizer.step()
        # get back the updated model
        model.get()
        if batch_idx % args['log_interval'] == 0:
            # a thing to note is the variable loss was
            # also created at the remote worker, so we need to
            # explicitly get it back
            loss = loss.get()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch,
                batch_idx * args['batch_size'],  # number of images processed so far
                len(train_loader) * args['batch_size'],  # total number of images
                100. * batch_idx / len(train_loader),
                loss.item()
                )
            )
The test function remains the same as in the official example, because testing runs locally on our machine, whereas training happens remotely.
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # add losses together
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # get the index of the max probability class
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
Train! Train! Train!
At last, we can start training the model, and the best part is that we use the same code as when we train a model locally.
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args['lr'])
logging.basicConfig(level=logging.INFO)  # make logging.info() messages visible
logging.info("Starting training !!")
for epoch in range(1, args['epochs'] + 1):
    train(args, model, device, federated_train_loader, optimizer, epoch)
    test(model, device, test_loader)
# that's all we need to do XD
Conclusion
Using exactly the code above, I was able to get an accuracy of 98%, which is quite good.
We were able to do all the training without ever looking at the data at those two remote locations, since the raw data stays with the schools. So now both schools can use this model to classify handwritten digits, and everyone is happy.
That sums up the federated learning part.
Outro
If you made it this far, pat yourself on the back. I know that’s quite a lot to take in, so here is a summary of everything we did:
- We learned why federated learning matters and how it differs from the usual centralized machine learning approach.
- We learned the basics of PySyft’s API, including working with tensors at a remote location using Pointer Tensors. We also saw that we can do normal PyTorch tensor stuff with Pointer Tensors as well.
- Then we simulated a real-life scenario in which two imaginary schools held the data, and we used federated learning to combine the power of their data into an effective model.
- We trained a model, and we hardly had to change the official PyTorch MNIST example to adapt it to a federated learning scenario.
Note that this is a very basic federated learning scenario. There are still a lot of flaws in this setup, for example:
- We can learn information about the data at a remote location by looking at the gradients, i.e. at how our model changes after training there. A common remedy is to use a Secure Aggregator that receives the models from all the remote locations, takes their mean, and only then sends the aggregated model back to us.
- Since we spent a lot of time building our model, we don’t want anyone at the remote location to see it. Right now, when we send our model to a remote location, it is completely exposed.
- Right now, training happens sequentially, i.e. one worker after another, which creates a performance bottleneck. Ideally we would train the model in parallel across all workers, since the training on each worker is independent of the others.
- We could use differential privacy, where some noise is injected into the data at the remote location, making it harder to learn about individual records from the training process.
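For the first point, one way to get "nobody sees an individual update" without trusting a single aggregator is additive secret sharing, a standard secure multi-party computation technique. A toy sketch with integers in plain Python, assuming two non-colluding aggregators; share, reconstruct, the modulus Q and the update values are hypothetical:

```python
import random

Q = 2**31 - 1  # all arithmetic is done modulo this (hypothetical) field size


def share(value, n_parties, rng):
    """Split an integer into n random-looking shares that sum to it modulo Q."""
    shares = [rng.randrange(Q) for _ in range(n_parties - 1)]
    shares.append((value - sum(shares)) % Q)
    return shares


def reconstruct(shares):
    return sum(shares) % Q


# random.Random is NOT cryptographically secure; it is used here only for a toy demo.
rng = random.Random(42)

# Each school's integer-encoded model update; nobody should see these individually.
updates = {"westside": 120, "grapevine": 80}
n = len(updates)

# Every school splits its update and hands one share to each aggregator.
all_shares = [share(v, n, rng) for v in updates.values()]

# Aggregator j only ever sees the j-th share from each school: random-looking numbers.
partial_sums = [sum(s[j] for s in all_shares) % Q for j in range(n)]

total = reconstruct(partial_sums)
print(total, total / n)  # 200 100.0
```

A real system would use a cryptographically secure random source (such as Python’s secrets module) and encode floating-point weights as fixed-point integers; the sketch skips both for brevity.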
Why did I write this article?
This is something I feel every author should ask themselves when writing a post. So here are some of the reasons that made me write this article, in no particular order (P.S. these are just my thoughts):
- I think people should understand the importance of privacy in datasets. I also think that not being able to work on sensitive data discourages researchers from working on people’s real problems, so we resort to solving problems like face detection. How many people have you heard of, or know, who have worked on a cancer cell dataset? I bet not many.
- When I started reading about Federated Learning, there wasn’t any beginner-friendly article that also explained PySyft’s API in the same post. Of course, you could read Syft’s tutorials and their blog posts on federated learning separately (both are great resources), but there were still things I wish they had explained in more detail, so I try to cover those in this post.
That’s it from my side 🙂
Thank you so much for reading this. Until next time!!
Some Extra Goodness just for you guys
These are some things I’m sure you would enjoy:
- Check out Andrew Trask’s video from the PyTorch Conference 2019; it’s really great. That’s the video that inspired me to learn more about privacy-preserving AI.
- You can also check out Udacity’s course on Private AI if you want to learn more.
- Here is the Federated Learning talk from the TensorFlow Dev Summit 2019. It’s a really great video.
- Don’t forget to check out PySyft’s GitHub repository here and OpenMined’s website here.