Document scanning is a background segmentation problem that can be solved in various ways, and it is one of the most widely used applications of computer vision. In this post, we treat document scanning as a semantic segmentation problem and use the DeepLabv3 semantic segmentation architecture to train a document segmentation model on a custom dataset.
Document Scanner using Semantic Segmentation Architecture DeepLabv3 – PyTorch [TL;DR]
Previously, we explored classical computer vision techniques in an effort to automate the pipeline. Check out the post “Automatic Document Scanner using OpenCV”, where we built a document scanner entirely with OpenCV. However, as documented there, the results were imperfect in some cases. The failures stemmed from our biased assumptions about the structure and placement of the documents and about background variations.
In this article, we train a semantic segmentation model on a custom dataset to improve the results. The steps for creating a document segmentation model are as follows:
- Collect a dataset and pre-process it with strong augmentation to increase robustness.
- Build a custom dataset class in PyTorch to load and pre-process image-mask pairs.
- Select and load a suitable deep learning model for transfer learning.
- Choose an appropriate loss function and evaluation metrics, then train the model.
- Image Segmentation Prerequisites
- Why use a deep learning-based solution for Document Segmentation?
- Workflow for Training a Custom Semantic Segmentation Model
- Preparing Synthetic Dataset for Robust Document Segmentation
- A Custom Dataset Class for Loading Documents and Masks
- Loading Pre-trained DeeplabV3 Semantic Segmentation Models
- Selecting Loss and Metric Functions IoU and Dice
- Custom Training Image Segmentation Model
- Inference and Comparison of Results
- Summary: Document Segmentation
1. Image Segmentation Prerequisites
This post assumes that you are familiar with the basic concepts of image segmentation, semantic segmentation and PyTorch. If not, we’ve got you covered.
- Our previous blog post covers image segmentation in detail. You can also watch the accompanying video: Image Segmentation | LearnOpenCV.
- To get familiar with PyTorch, see our series of blog posts: Learn PyTorch.
2. Why use a deep learning-based solution for Document Segmentation?
Robustness. As seen in the previous post, Automatic Document Scanner using OpenCV | LearnOpenCV, making a document scanner perform effectively across many scenarios is challenging. For a document scanner to be robust, the algorithm used for document extraction must be free of biased assumptions. Our solution is to build a deep learning-based image segmentation model for document segmentation. This post shows how to create and train a custom semantic segmentation model for the task using the DeepLabv3 architecture in PyTorch.
3. Workflow for Training a Custom Semantic Segmentation Model
This section outlines the workflow, from generating a synthetic dataset to training the document segmentation model. For the custom semantic segmentation model, we use the pre-trained DeepLabV3 architecture with a MobileNetV3-Large backbone.
The steps for creating a robust document segmentation model are as follows:
- As with any project, once the problem statement is defined, the next crucial step is deciding how to collect the dataset for the task.
- The aim is a robust document segmentation model that works well across many scenarios. To get there, we generate a synthetic dataset from background and document images collected from various sources.
- With the synthetic dataset in place, we create a custom PyTorch Dataset class responsible for loading and preprocessing the image-mask pairs.
- Next, we choose and load a deep learning model suited to the task: the DeepLabV3 architecture with the MobileNetV3-Large backbone, readily available in torchvision.
- Before training, the final component is selecting an appropriate loss function and evaluation metric. We briefly discuss the two most common measures for segmentation problems, Intersection over Union (IoU) and the Dice coefficient, and select the one best suited to our task.
4. Preparing Synthetic Dataset For Robust Document Segmentation
Let us start by preparing the custom dataset. A common rule of thumb is that about 80% of the time in a machine learning or deep learning project goes into dataset collection, preparation and analysis, and about 20% into actual training and improvement.
We aim to create a robust document segmentation model. To do so, we need a dataset containing various documents on multiple backgrounds, captured in different orientations. As you can imagine, collecting such a dataset is very time-consuming (as it is for any project), so we take another route.
Here, we generate a synthetic dataset that simulates the problems associated with capturing real-world images, such as motion blur and camera noise. This procedure helps us overcome some of the shortcomings of a manually created dataset and removes the hassle of capturing and annotating pictures.
4.1 Gathering and Pre-Processing Document and Background Images
To generate a synthetic dataset, we need two sets of images:
- Images of different types of documents.
- A wide variety of background images.
| Dataset | Total images available | Images used (randomly selected) |
|---|---|---|
| 1. Val and Test set from DocVQA | 2573 | 700 |
| 2. IAM Handwriting Database | 522 | 100 |
| 3. Denoising Dirty Documents | 360 | 125 |
| 5. LRDE Document Dataset | 125 | 125 |
| 6. SmartDoc QA | 4260 | 85 |
Table 1: Details regarding Document image sources and number of images used.
The document images were pre-processed as follows:
- The images from these sources vary widely in dimensions, so all documents were resized, preserving aspect ratio, so that the longer side is 640 pixels. This gives the documents a consistent scale and reduces processing time.
- Because every image is a cropped document, the mask for each document is simply an array of the same shape filled with the value 255.
- SmartDoc QA contains 4260 raw images, of which 85 documents were manually annotated and extracted.
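The two pre-processing steps above can be sketched as follows. This is a minimal, dependency-free illustration, assuming images are NumPy arrays of shape (H, W, 3); the function names are hypothetical, and the NumPy nearest-neighbour indexing is a stand-in for `cv2.resize`, which would normally be used.

```python
import numpy as np


def resize_keep_aspect(image: np.ndarray, max_side: int = 640) -> np.ndarray:
    """Resize so that the longer side equals max_side, preserving aspect ratio.

    Nearest-neighbour sampling via NumPy indexing (a stand-in for cv2.resize).
    """
    h, w = image.shape[:2]
    scale = max_side / max(h, w)
    new_h = max(1, round(h * scale))
    new_w = max(1, round(w * scale))
    # Map each output row/column back to its nearest source row/column.
    rows = (np.arange(new_h) * h / new_h).astype(int)
    cols = (np.arange(new_w) * w / new_w).astype(int)
    return image[rows][:, cols]


def full_document_mask(document: np.ndarray) -> np.ndarray:
    """A cropped document fills the whole frame, so every mask pixel is 255."""
    return np.full(document.shape, 255, dtype=np.uint8)


doc = np.zeros((1280, 960, 3), dtype=np.uint8)
small = resize_keep_aspect(doc)
mask = full_document_mask(small)
print(small.shape, mask.shape)  # (640, 480, 3) (640, 480, 3)
```

The same function also upscales images whose longer side is below 640, matching the "maximum dimension set to 640" rule.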
Background images: One of the goals of generating a synthetic set is to simulate documents placed on different backgrounds. We built the background image dataset from Google Image Search results, downloading the images returned for queries such as “table images top view”, “laminate sheet close up image”, “Wooden table close up”, etc. The queries were chosen to return images with a variety of textures and colors.
To ease the download process, we used a fork of the Google Images Download repository. All images were either downloaded in, or converted to, JPG format. After filtering out duplicate images, we were left with 1055 background images.
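The post does not say how duplicates were detected. One simple option, sketched here as an assumption rather than the method actually used, is to hash each file's bytes and flag any file whose hash has already been seen. The function name is hypothetical.

```python
import hashlib
from pathlib import Path


def find_exact_duplicates(image_dir: str) -> list:
    """Return paths of JPG files whose bytes duplicate an earlier file.

    Files are visited in sorted order; the first file with a given
    MD5 digest is kept and later ones are reported as duplicates.
    """
    seen = {}
    duplicates = []
    for path in sorted(Path(image_dir).glob("*.jpg")):
        digest = hashlib.md5(path.read_bytes()).hexdigest()
        if digest in seen:
            duplicates.append(path)
        else:
            seen[digest] = path
    return duplicates
```

The returned paths can then be deleted with `Path.unlink()`. This only catches byte-identical files; re-encoded or resized copies of the same picture would need a perceptual hash instead.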
4.2 Procedure for Generating Synthetic Dataset for Document Segmentation
The diagram above shows the flow for generating one image-mask pair. The main goal of synthetic data generation is to place documents on various backgrounds and at various positions. To this end, each document-mask pair is passed through a RandomPerspective transformation 6 times, with a 70% probability that the transformation is applied each time; the document image (but not its mask) additionally receives a RandomBrightnessContrast transformation with 50% probability.
For the 6 transformed document-mask pairs, 6 background images are randomly selected from the full set. Each transformed document is then merged with its chosen background. Each merged output pair (image and mask) is subjected to further augmentations to replicate real-world conditions as closely as possible.
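The merge step can be sketched as a masked paste: wherever the (warped) mask is 255, take the document pixel, and elsewhere take the background pixel. This is a minimal NumPy sketch under the assumptions that the background is resized to the document canvas (here by nearest-neighbour indexing in place of `cv2.resize`) and that backgrounds are drawn without replacement; both function names are hypothetical.

```python
import numpy as np


def merge_on_background(document, mask, background):
    """Paste document pixels onto the background wherever the mask is 255.

    document: (H, W, 3) uint8; mask: (H, W) or (H, W, 3) uint8 with values 0/255;
    background: (h, w, 3) uint8 of any size, resized here to H x W.
    """
    height, width = mask.shape[:2]
    # Nearest-neighbour resize of the background to the document canvas.
    rows = (np.arange(height) * background.shape[0] / height).astype(int)
    cols = (np.arange(width) * background.shape[1] / width).astype(int)
    resized_bg = background[rows][:, cols]
    # Boolean document region with a trailing axis so it broadcasts over RGB.
    region = (mask if mask.ndim == 2 else mask[..., 0]) == 255
    return np.where(region[..., None], document, resized_bg)


def pick_backgrounds(num_backgrounds, k=6, seed=None):
    """Randomly choose k distinct background indices."""
    rng = np.random.default_rng(seed)
    return rng.choice(num_backgrounds, size=k, replace=False)


doc = np.full((40, 30, 3), 200, dtype=np.uint8)
mask = np.zeros((40, 30), dtype=np.uint8)
mask[10:20, 5:15] = 255
bg = np.full((100, 80, 3), 10, dtype=np.uint8)
merged = merge_on_background(doc, mask, bg)
chosen = pick_backgrounds(1055, seed=0)
```

The mask itself needs no change: after merging, it still marks exactly the document pixels.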
The augmentation set consists of the following modifications.
- Horizontal and Vertical Flipping
- Color Jitter
- Channel Shuffle
- Random Brightness and Contrast
- One of: Image Compression, ISO Noise, Motion Blur
- One of: Random Shadow, Sun-Flare, RGB Shift
- One of: Optical Distortion, Grid Distortion, Elastic Transformation
All transformations except the perspective transformation are applied using the Albumentations library.
A total of 8058 image-mask pairs were generated, out of which 6715 were selected for the training set, and the remaining 1343 made up the validation set.
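The split sizes are consistent with holding out one sixth of the pairs (8058 / 6 = 1343, and 8058 - 1343 = 6715). The post does not say how the pairs were assigned, so the shuffled split below is only an assumption, with hypothetical file names.

```python
import random


def split_pairs(pairs, num_valid, seed=42):
    """Shuffle (image_path, mask_path) pairs and hold out num_valid for validation."""
    shuffled = list(pairs)
    random.Random(seed).shuffle(shuffled)
    return shuffled[num_valid:], shuffled[:num_valid]


pairs = [(f"img_{i}.jpg", f"mask_{i}.png") for i in range(8058)]
train_pairs, valid_pairs = split_pairs(pairs, num_valid=1343)
print(len(train_pairs), len(valid_pairs))  # 6715 1343
```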
5. A Custom Dataset Class for Loading Documents and Masks
A custom Dataset class loads each image-mask pair and converts it into the appropriate format. The steps are identical for the training and validation sets except for the image preprocessing transforms: for the training set, an additional augmentation, “RandomGrayscale”, is applied to the images with a probability of 40%.
This additional augmentation makes training slightly more challenging, forcing the model to better learn the difference between (any type of) document and the background. The choice was based on results from multiple experiments: models trained only on grayscale images or only on RGB images both performed well, but each failed in cases where the other succeeded.
- The mask for each image has 2 channels: one for the background and one for the document.
- All returned images and masks are first rescaled to the range [0, 1]. Images are further normalized with the per-channel mean and standard deviation values given in the code below.
The following block contains the code to prepare the document segmentation dataset.
```python
import cv2
import numpy as np
import torch
import torchvision.transforms as torchvision_T
from torch.utils.data import Dataset


def train_transforms(mean=(0.4611, 0.4359, 0.3905), std=(0.2193, 0.2150, 0.2109)):
    transforms = torchvision_T.Compose([
        torchvision_T.ToTensor(),              # HWC uint8 [0, 255] -> CHW float [0, 1]
        torchvision_T.RandomGrayscale(p=0.4),  # Training-only augmentation
        torchvision_T.Normalize(mean, std),
    ])
    return transforms


def common_transforms(mean=(0.4611, 0.4359, 0.3905), std=(0.2193, 0.2150, 0.2109)):
    transforms = torchvision_T.Compose([
        torchvision_T.ToTensor(),
        torchvision_T.Normalize(mean, std),
    ])
    return transforms


class SegDataset(Dataset):
    def __init__(self, *, img_paths, mask_paths, image_size=(384, 384), data_type="train"):
        self.data_type = data_type
        self.img_paths = img_paths
        self.mask_paths = mask_paths
        self.image_size = image_size  # (width, height), as expected by cv2.resize

        if self.data_type == "train":
            self.transforms = train_transforms()
        else:
            self.transforms = common_transforms()

    def read_file(self, path):
        # OpenCV loads images as BGR; convert to RGB.
        file = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
        # Nearest-neighbour interpolation keeps mask values at exactly 0 or 255.
        file = cv2.resize(file, self.image_size, interpolation=cv2.INTER_NEAREST)
        return file

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        image = self.transforms(self.read_file(self.img_paths[index]))

        gt_mask = self.read_file(self.mask_paths[index])
        width, height = self.image_size
        _mask = np.zeros((height, width, 2), dtype=np.float32)
        # BACKGROUND
        _mask[:, :, 0] = np.where(gt_mask[:, :, 0] == 0, 1.0, 0.0)
        # DOCUMENT
        _mask[:, :, 1] = np.where(gt_mask[:, :, 0] == 255, 1.0, 0.0)

        # HWC -> CHW to match the image tensor layout.
        mask = torch.from_numpy(_mask).permute(2, 0, 1)
        return image, mask
```
The dataset objects for the training and validation sets are first initialized with the appropriate arguments and then wrapped in DataLoader objects.
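The wrapping step can be sketched as below, using the batch size of 64 from the training table. Shuffling only the training loader is standard practice; `drop_last=True` is an added assumption, not something the post specifies. It guards against a final training batch of size 1, which the DeepLabV3 model cannot process in train mode. The `TensorDataset` objects are stand-ins for the SegDataset instances.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def make_loaders(train_ds: Dataset, valid_ds: Dataset, batch_size: int = 64, num_workers: int = 0):
    """Wrap the training and validation datasets in DataLoaders."""
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, drop_last=True)
    valid_loader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False,
                              num_workers=num_workers)
    return train_loader, valid_loader


# Stand-in datasets: images [N, 3, H, W] and two-channel masks [N, 2, H, W].
train_ds = TensorDataset(torch.randn(10, 3, 8, 8), torch.zeros(10, 2, 8, 8))
valid_ds = TensorDataset(torch.randn(5, 3, 8, 8), torch.zeros(5, 2, 8, 8))
train_loader, valid_loader = make_loaders(train_ds, valid_ds, batch_size=4)
images, masks = next(iter(train_loader))
print(len(train_loader), len(valid_loader), tuple(images.shape))  # 2 2 (4, 3, 8, 8)
```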
6. Loading Pre-trained DeeplabV3 Semantic Segmentation Models
Torchvision provides three pre-trained variants of the DeepLabV3 architecture, which differ only in their backbone: ResNet-50, ResNet-101 and MobileNetV3-Large.
For our robust document segmentation task, we use the pre-trained DeepLabV3 model with the MobileNetV3-Large backbone. It is smaller than the other variants, yet it has a good mIoU score and high inference speed. It has over 11M parameters and was trained on a subset of COCO, using only the 20 categories present in the Pascal VOC dataset.
We fine-tune all of the model's layers because our target class differs significantly from the classes the model was pre-trained on.
```python
import torch
import torch.nn as nn
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large
from torchvision.models.segmentation import deeplabv3_resnet50, deeplabv3_resnet101


def prepare_model(backbone_model="mbv3", num_classes=2):
    # Initialize model with pre-trained weights (torchvision >= 0.13 weights API).
    weights = "DEFAULT"

    if backbone_model == "mbv3":
        model = deeplabv3_mobilenet_v3_large(weights=weights)
    elif backbone_model == "r50":
        model = deeplabv3_resnet50(weights=weights)
    elif backbone_model == "r101":
        model = deeplabv3_resnet101(weights=weights)
    else:
        raise ValueError("Wrong backbone model passed. Must be one of 'mbv3', 'r50' and 'r101'")

    # Update the number of output channels of the final 1x1 conv in each head.
    # This removes the pre-trained weights of these last layers only; ASPP is kept.
    model.classifier[4] = nn.LazyConv2d(num_classes, 1)
    model.aux_classifier[4] = nn.LazyConv2d(num_classes, 1)

    return model


# Dummy initialization (this forward pass also materializes the lazy layers).
model = prepare_model(num_classes=2)
model.train()
# In train mode, batch size needs to be at least 2 (BatchNorm in the ASPP pooling branch).
out = model(torch.randn((2, 3, 384, 384)))
print(out["out"].shape)  # torch.Size([2, 2, 384, 384])
```
7. Selecting Loss and Metric Functions IoU and Dice
So far, we have covered the Dataset class and a function to initialize the DeepLabV3 model. Next, we need to select a loss function and an evaluation metric suited to the task.
Two measures are widely used for segmentation problems, and both can serve as either a loss function or an evaluation metric:
- Intersection over Union (IoU), defined as |G ∩ P| / |G ∪ P| for a ground-truth mask G and a predicted mask P, is commonly used to assess a model’s accuracy in segmentation problems. Unlike plain pixel accuracy, it is not biased by the (often unbalanced) share of pixels belonging to any particular class.
- The Dice coefficient (also known as the F1-score) is another segmentation metric, closely related to IoU. It is defined as 2|G ∩ P| / (|G| + |P|): twice the overlap area divided by the total number of pixels in the ground truth and the prediction.
Both metrics range from 0 to 1 and are positively correlated. The key difference lies in how they penalize wrong predictions: relative to the true positives, IoU weights false positive (FP) and false negative (FN) predictions twice as heavily as Dice does.
Apart from minor differences in how model predictions are processed, both measures can serve as a loss (loss = 1 - metric) and as a metric, because the core computation is the same. We therefore define a single shared function that returns the metric value.
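Writing both measures in terms of true positives (TP), false positives (FP) and false negatives (FN) makes the "twice as heavily" comparison explicit and shows that each score can be computed from the other:

$$
\mathrm{IoU} = \frac{|G \cap P|}{|G \cup P|} = \frac{TP}{TP + FP + FN},
\qquad
\mathrm{Dice} = \frac{2\,|G \cap P|}{|G| + |P|} = \frac{2\,TP}{2\,TP + FP + FN} = \frac{TP}{TP + \tfrac{1}{2}(FP + FN)}
$$

$$
\mathrm{Dice} = \frac{2\,\mathrm{IoU}}{1 + \mathrm{IoU}},
\qquad
\mathrm{IoU} = \frac{\mathrm{Dice}}{2 - \mathrm{Dice}}
$$

In the IoU denominator, FP + FN is added with weight 1, while in the rearranged Dice denominator it carries weight 1/2. Because the mapping between the two is monotonic, they always rank predictions in the same order.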
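A small worked example makes the shared computation concrete. This NumPy sketch (hypothetical function name, binary masks, no smoothing term, so the two masks must not both be empty) derives both scores from one intersection and one summation, exactly as the joint function does.

```python
import numpy as np


def overlap_scores(pred, target):
    """Return (IoU, Dice) for two binary masks from one shared computation."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()  # |G ∩ P|
    summation = pred.sum() + target.sum()              # |G| + |P|
    union = summation - intersection                   # |G ∪ P|
    return intersection / union, 2 * intersection / summation


target = np.zeros((4, 4), dtype=np.uint8)
target[:, :2] = 1   # 8 ground-truth pixels (columns 0-1)
pred = np.zeros((4, 4), dtype=np.uint8)
pred[:, 1:3] = 1    # 8 predicted pixels (columns 1-2), 4 overlapping
iou, dice = overlap_scores(pred, target)
print(iou, dice)    # 0.3333333333333333 0.5
```

The corresponding losses would be 1 - IoU ≈ 0.667 and 1 - Dice = 0.5, illustrating that IoU penalizes the same mistakes more heavily.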
7.1 Implementing Loss Function and Evaluation Metric
We train the custom document segmentation model with a combined loss, the sum of the IoU loss and binary cross-entropy (BCE), and track IoU as the evaluation metric.
```python
def intermediate_metric_calculation(predictions, targets, use_dice=False, smooth=1e-6, dims=(2, 3)):
    # dims correspond to image height and width: [B, C, H, W].
    # Intersection: |G ∩ P|. Shape: (batch_size, num_classes)
    intersection = (predictions * targets).sum(dim=dims)
    # Summation: |G| + |P|. Shape: (batch_size, num_classes)
    summation = predictions.sum(dim=dims) + targets.sum(dim=dims)

    if use_dice:
        # Dice = 2|G ∩ P| / (|G| + |P|). Shape: (batch_size, num_classes)
        metric = (2.0 * intersection + smooth) / (summation + smooth)
    else:
        # Union: |G ∪ P| = |G| + |P| - |G ∩ P|. Shape: (batch_size, num_classes)
        union = summation - intersection
        # IoU = |G ∩ P| / |G ∪ P|. Shape: (batch_size, num_classes)
        metric = (intersection + smooth) / (union + smooth)

    # Adding `smooth` to both numerator and denominator avoids division by zero
    # and yields a score of 1 when both masks are empty.
    # Mean over the remaining axes (batch and classes). Shape: scalar
    total = metric.mean()
    return total
```
The main code for the Loss Class:
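```python
import torch
import torch.nn.functional as F


# Core of the Loss class, shown as its forward() method; __init__ (which sets
# self.use_dice and self.smooth) is not shown.
def forward(self, predictions, targets):
    # Normalize model predictions (logits) into probabilities.
    predictions = torch.sigmoid(predictions)
    # Calculate pixel-wise BCE loss for both channels. Shape: scalar
    pixel_loss = F.binary_cross_entropy(predictions, targets, reduction="mean")
    # Soft IoU (or Dice) loss on the probabilities.
    mask_loss = 1 - intermediate_metric_calculation(predictions, targets, use_dice=self.use_dice, smooth=self.smooth)
    # Total loss is the sum of both terms.
    total_loss = mask_loss + pixel_loss
    return total_loss
```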
Similarly, for the Metric Class:
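```python
# Core of the Metric class, shown as its forward() method; __init__ (which sets
# self.num_classes, self.use_dice and self.smooth) is not shown.
def forward(self, predictions, targets):
    # Convert unnormalized predictions into one-hot encoded masks across channels.
    # Shape: (B, #C, H, W)
    predictions = convert_2_onehot(predictions, num_classes=self.num_classes)
    # Compute and return the metric.
    metric = intermediate_metric_calculation(predictions, targets, use_dice=self.use_dice, smooth=self.smooth)
    return metric
```
The convert_2_onehot function is a separate helper (not shown here) that converts the model's per-channel predictions into one-hot values.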
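Since the helper is not shown, here is a hypothetical sketch of what `convert_2_onehot` could look like, assuming raw logits of shape [B, C, H, W]. At each pixel, the highest-scoring channel becomes 1 and all others become 0, so the metric is computed on hard predictions rather than probabilities.

```python
import torch
import torch.nn.functional as F


def convert_2_onehot(predictions: torch.Tensor, num_classes: int = 2) -> torch.Tensor:
    """Hypothetical sketch: turn raw logits [B, C, H, W] into one-hot masks [B, C, H, W]."""
    # The highest-scoring channel wins at each pixel. Shape: [B, H, W]
    labels = predictions.argmax(dim=1)
    # One-hot encode along a new last axis, then move it to the channel position.
    onehot = F.one_hot(labels, num_classes=num_classes)  # [B, H, W, C]
    return onehot.permute(0, 3, 1, 2).float()


# One image, two channels, a 1 x 2 pixel grid: pixel 0 favors channel 0, pixel 1 channel 1.
logits = torch.tensor([[[[2.0, -1.0]], [[0.5, 3.0]]]])
out = convert_2_onehot(logits)
```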
8. Custom Training Document Segmentation Model
Now that we have defined all the required components, we are ready to train our custom semantic segmentation model for document segmentation. The hyperparameters and results of the final training run were as follows.
| Hyperparameter | Value |
|---|---|
| 1. Backbone Model | MobileNetV3-Large |
| 2. Image Shape | 384 × 384 × 3 |
| 3. Mask Shape | 384 × 384 × 2 |
| 4. Number of output channels | 2 (0: background, 1: document) |
| 5. Number of Epochs Trained | 50 |
| 6. Batch Size | 64 |
| 8. Learning Rate | 0.0001 (constant) |
| 9. Loss Function(s) | Binary cross-entropy + Intersection over Union (BCE + IoU) |
| 10. Evaluation Metric | Intersection over Union (IoU) |
| 11. Best Loss Values | Training: 0.027, Validation: 0.076 |
| 12. Best Metric Scores | Training: 0.989, Validation: 0.976 |
Table 2: Training hyperparameters and final scores
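The training procedure itself is not listed in the post, so the following is only a generic one-epoch loop sketched under stated assumptions. It assumes the model returns a dict with an `"out"` entry, as torchvision's DeepLabV3 models do, and ignores the auxiliary output, which may differ from the original training. The tiny stand-in model, random data, `BCEWithLogitsLoss` (in place of the BCE + IoU loss class) and Adam optimizer are all assumptions; only the learning rate of 0.0001 comes from Table 2.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_one_epoch(model, loader, loss_fn, optimizer, device="cpu"):
    """Run one training epoch and return the mean loss per batch."""
    model.train()
    total = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        logits = model(images)["out"]  # main head only; aux output ignored here
        loss = loss_fn(logits, masks)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / max(1, len(loader))


class TinySegModel(nn.Module):
    """Stand-in for DeepLabV3: a single 1x1 conv that returns {"out": logits}."""

    def __init__(self, num_classes=2):
        super().__init__()
        self.head = nn.Conv2d(3, num_classes, kernel_size=1)

    def forward(self, x):
        return {"out": self.head(x)}


data = TensorDataset(torch.randn(8, 3, 16, 16), torch.randint(0, 2, (8, 2, 16, 16)).float())
loader = DataLoader(data, batch_size=4)
model = TinySegModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)  # optimizer choice is an assumption
epoch_loss = train_one_epoch(model, loader, nn.BCEWithLogitsLoss(), optimizer)
```

With the post's loss class, which applies the sigmoid internally, that object would be passed as `loss_fn` instead.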
DeepLabV3 with a ResNet-50 backbone, trained for 25 epochs with otherwise the same hyperparameters as in the table above, gave slightly better results than the MobileNetV3-Large backbone. We still prefer MobileNetV3-Large: the difference in results was only about 0.07, and the lighter model lets us run more experiments quickly.
During experimentation, we tested many hyperparameter combinations, including a Dice + BCE loss, a single-channel output, different image sizes, training on all-grayscale and on all-RGB images, and training with and without class weight masks. The hyperparameters in the table above gave the best performance.
9. Inference and Comparison of Results
We created a test set of 51 images: 23 images from the previous post (including its failure cases) and 28 newly captured images.
Results are shown for the 25 images where the two approaches differ the most.
Fig: Document Segmentation Comparison Results
We further tested the models on the dataset used in the paper DocUNet: Document Image Unwarping.
These tests indicate that there is still room to improve both the synthetic dataset generation procedure and the training process.
10. Summary: Document Segmentation
In this post, our goal was to replace our previous document extraction approach with a much more robust one. To do so, we moved away from traditional CV algorithms and created a deep learning-based custom semantic segmentation model for document segmentation. A deep learning-based approach frees us from the assumptions we had to make when working with traditional CV algorithms.
This post covered generating a synthetic dataset, defining appropriate loss and metric functions for image segmentation, and training a custom DeepLabV3 model in PyTorch. The results presented in the Inference and Comparison of Results section indicate that the deep learning-based approach is a significant improvement over our previous method. That said, as testing on the DocUNet dataset showed, there is still room for improvement.