
A Step-by-Step Tutorial on Image Segmentation using TensorFlow Hub


In this post, we will learn how to perform semantic image segmentation using pre-trained models available in TensorFlow Hub. TensorFlow Hub is a library and platform designed for sharing, discovering, and reusing pre-trained machine learning models. The primary goal of TensorFlow Hub is to simplify the process of reusing existing models, thereby promoting collaboration, reducing redundant work, and accelerating research and development in machine learning. Users can search for pre-trained models, called modules, that have been contributed by the community or provided by Google. These modules can be easily integrated into a user’s own machine learning projects with just a few lines of code.
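
For example, loading a module typically takes a single call to hub.load() with the model handle (the handle below is the HRNet segmentation model used later in this post):

import tensorflow_hub as hub

# Load a pre-trained module directly from its tfhub.dev handle.
model = hub.load('https://tfhub.dev/google/HRNet/camvid-hrnetv2-w48/1')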

Image Segmentation is analogous to image classification but at the pixel level. The goal of image segmentation is to simplify the representation of an image and make it more meaningful for analysis or further processing. In other words, it aims to separate the important parts of an image, such as objects or areas of interest, from the background or irrelevant areas. You can read more about Image Segmentation in our introductory post on the subject.

In this example, we will use the image segmentation model camvid-hrnetv2-w48, which was trained on CamVid (the Cambridge-driving Labeled Video Database), a driving and scene understanding dataset containing images extracted from five video sequences recorded during real-world driving scenarios. The dataset contains 32 classes. Several other image segmentation models are available on TensorFlow Hub as well.

import os
import numpy as np
import cv2
import zipfile
import requests
import glob

import tensorflow as tf
import tensorflow_hub as hub

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

import warnings
import logging
import absl

# Filter absl warnings
warnings.filterwarnings("ignore", module="absl")

# Capture all warnings in the logging system
logging.captureWarnings(True)

# Set the absl logger level to 'error' to suppress warnings
absl_logger = logging.getLogger("absl")
absl_logger.setLevel(logging.ERROR)

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

Download Sample (CamVid) Images

def download_file(url, save_name):
    file = requests.get(url)

    # Write the downloaded content to disk.
    with open(save_name, 'wb') as f:
        f.write(file.content)
def unzip(zip_file=None):
    try:
        with zipfile.ZipFile(zip_file) as z:
            z.extractall("./")
            print("Extracted all")
    except zipfile.BadZipFile:
        print("Invalid file")
download_file(
    'https://www.dropbox.com/s/5jhbvmqgzbzl9fd/camvid_images.zip?dl=1',
    'camvid_images.zip'
)

unzip(zip_file='camvid_images.zip')
Extracted all
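
As an aside, the download-and-extract step can also be handled by tf.keras.utils.get_file; a minimal sketch (note that get_file caches under <cache_dir>/datasets, so with cache_dir='.' the files land under ./datasets rather than in the current directory):

# Alternative sketch: let Keras handle both the download and the extraction.
tf.keras.utils.get_file(
    fname='camvid_images.zip',
    origin='https://www.dropbox.com/s/5jhbvmqgzbzl9fd/camvid_images.zip?dl=1',
    extract=True,
    cache_dir='.'
)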

Display Sample Images

image_paths = sorted(glob.glob('camvid_images/*.png'))

for idx in range(len(image_paths)):
    print(image_paths[idx])
camvid_images/camvid_sample_1.png
camvid_images/camvid_sample_2.png
camvid_images/camvid_sample_3.png
camvid_images/camvid_sample_4.png
def load_image(path):

    image = cv2.imread(path)
    
    # Convert image in BGR format to RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    
    # Add a batch dimension (required by the model) and scale pixel values to [0, 1].
    image = np.expand_dims(image, axis=0)/255.0
    
    return image
images = []
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(16, 12))

for idx, axis in enumerate(ax.flat):
    image = load_image(image_paths[idx])
    images.append(image)
    axis.imshow(image[0])
    axis.axis('off')
camvid test images

Define a Dictionary that Maps Class IDs to Class Names and Class Colors

class_index is a dictionary that maps each of the 32 class IDs in the CamVid dataset to its associated RGB color label and class name.

class_index = \
    {
         0: [(64, 128, 64),  'Animal'],
         1: [(192, 0, 128),  'Archway'],
         2: [(0, 128, 192),  'Bicyclist'],
         3: [(0, 128, 64),   'Bridge'],
         4: [(128, 0, 0),    'Building'],
         5: [(64, 0, 128),   'Car'],
         6: [(64, 0, 192),   'Cart/Luggage/Pram'],
         7: [(192, 128, 64), 'Child'],
         8: [(192, 192, 128),'Column Pole'],
         9: [(64, 64, 128),  'Fence'],
        10: [(128, 0, 192),  'LaneMkgs Driv'],
        11: [(192, 0, 64),   'LaneMkgs NonDriv'],
        12: [(128, 128, 64), 'Misc Text'],
        13: [(192, 0, 192),  'Motorcycle/Scooter'],
        14: [(128, 64, 64),  'Other Moving'],
        15: [(64, 192, 128), 'Parking Block'],
        16: [(64, 64, 0),    'Pedestrian'],
        17: [(128, 64, 128), 'Road'],
        18: [(128, 128, 192),'Road Shoulder'],
        19: [(0, 0, 192),    'Sidewalk'],
        20: [(192, 128, 128),'Sign Symbol'],
        21: [(128, 128, 128),'Sky'],
        22: [(64, 128, 192), 'SUV/Pickup/Truck'],
        23: [(0, 0, 64),     'Traffic Cone'],
        24: [(0, 64, 64),    'Traffic Light'],
        25: [(192, 64, 128), 'Train'],
        26: [(128, 128, 0),  'Tree'],
        27: [(192, 128, 192),'Truck/Bus'],
        28: [(64, 0, 64),    'Tunnel'],
        29: [(192, 192, 0),  'Vegetation Misc'],
        30: [(0, 0, 0),      'Void'],
        31: [(64, 192, 0),   'Wall']  
    }
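
Each entry maps a class ID to its (color, name) pair, so looking up a class is a direct dictionary access:

# Look up the RGB color and class name for class ID 17.
color, name = class_index[17]
print(color, name)
(128, 64, 128) Road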

Model Inference using TensorFlow Hub

TensorFlow Hub contains many different pre-trained segmentation models. Here we will use the High-Resolution Network (HRNet) segmentation model trained on CamVid (camvid-hrnetv2-w48). The model was pre-trained on the ImageNet ILSVRC-2012 classification task and fine-tuned on CamVid.

Load the Model from TensorFlow Hub

We can load the model into memory using the URL to the model page.

model_url = 'https://tfhub.dev/google/HRNet/camvid-hrnetv2-w48/1'
print('loading model: ', model_url)

seg_model = hub.load(model_url)
print('\nmodel loaded!')
loading model:  https://tfhub.dev/google/HRNet/camvid-hrnetv2-w48/1
model loaded!

Perform Inference

Before we formalize the code to process several images and post-process the results, let’s first see how to perform inference on a single image and study the output from the model.

Call the Model’s predict() Method

# Make a prediction using the first image in the list of images.
pred_mask = seg_model.predict(images[0])

# The predicted mask has the following shape: [B, H, W, C].
print('Shape of predicted mask: ', pred_mask.shape)
Shape of predicted mask:  (1, 720, 960, 33)
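
Note that the mask has 33 channels rather than 32: the model prepends an extra background class to the 32 CamVid classes. A quick sanity check against our class dictionary:

# One channel per CamVid class plus one background channel added by the model.
print(pred_mask.shape[-1], len(class_index))
33 32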

Post-Process the Predicted Segmentation Mask

The predicted segmentation mask returned by the model contains a separate channel for each class. Each channel contains the probability that a given pixel from the input image is associated with the class for that channel. This data, therefore, requires some post-processing to obtain meaningful results. Several steps need to be performed to arrive at a final visual representation.

  1. Remove the batch dimension and the background class.
  2. Assign a class label to every pixel in the image based on the highest probability score across all channels.
  3. The previous step results in a single-channel image that contains the class labels for each pixel. We, therefore, need to map those class IDs to RGB values so we can visualize the results as a color-coded segmentation map.

Remove Batch Dimension and Background Class

# Convert tensor to numpy array.
pred_mask = pred_mask.numpy()

# The first channel is the background class added by the model, which we remove for this dataset.
pred_mask = pred_mask[:,:,:,1:]

# We also need to remove the batch dimension.
pred_mask = np.squeeze(pred_mask)

# Print the shape to confirm: [H, W, C]. 
print('Shape of predicted mask after removal of batch dimension and background class: ', pred_mask.shape)
Shape of predicted mask after removal of batch dimension and background class:  (720, 960, 32)

Visualize the Intermediate Results

# Each channel in `pred_mask` contains the probabilities that the pixels 
# in the original image are associated with the class for that channel.
plt.figure(figsize=(20,6))

plt.subplot(1,3,1)
plt.title('Input Image', fontsize=14)
plt.imshow(np.squeeze(images[0]))

plt.subplot(1,3,2)
plt.title('Predictions for Class: Road', fontsize=14)
plt.imshow(pred_mask[:,:,17], cmap='gray');  # Class 17 corresponds to the 'road' class
plt.axis('off')

plt.subplot(1,3,3)
plt.title('Predictions for Class: Sky', fontsize=14)
plt.imshow(pred_mask[:,:,21], cmap='gray');  # Class 21 corresponds to the 'sky' class
plt.axis('off');
segmentation prediction sample images

Assign Each Pixel a Class Label

Here we assign every pixel in the image a class ID based on the class with the highest probability. We can visualize this as a grayscale image. In the code cell below, we will display just the top portion of the image to highlight a few of the class assignments.

# Assign each pixel in the image a class ID based on the channel that contains the  
# highest probability score. This can be implemented using the `argmax` function.
pred_mask_class = np.argmax(pred_mask, axis=-1)

plt.figure(figsize=(15,5)); 

plt.subplot(1,2,1)
plt.title('Input Image', fontsize=12)
plt.imshow(np.squeeze(images[0]))

plt.subplot(1,2,2)
plt.title('Segmentation Mask', fontsize=12)
plt.imshow(pred_mask_class, cmap='gray') 
plt.gca().add_patch(Rectangle((450,200),200,3, edgecolor='red', facecolor='none', lw=.5));
Sample grayscale segmentation map

Let’s now inspect a small region of the segmentation mask to better understand how the values map to class IDs. For reference, the top portion (200 rows) of the segmentation mask (pred_mask_class) has been overlaid on the input image. Notice that regions in the segmentation mask correspond to distinct regions in the input image (e.g., buildings, sky, trees).

Example grayscale segmentation results overlaid on the original image

# Print the class IDs from the region highlighted by the red rectangle above.
print(pred_mask_class[200,450:650])
[ 4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4
  4  4  4  4  4  4  4  4 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21
 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21
 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21
 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21
 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21
 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21 21
 26 26 21 21 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26 26
 26 26 26 26 26 26 26 26]

Notice that the values in pred_mask_class for the small section indicated by the red rectangle correspond to the class IDs for Building (4), Sky (21), and Tree (26).
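
We can confirm this programmatically by mapping the unique class IDs in that slice back to their class names:

# Map the unique class IDs in the inspected slice to their class names.
for class_id in np.unique(pred_mask_class[200, 450:650]):
    print(class_id, class_index[class_id][1])
4 Building
21 Sky
26 Tree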

Convert the Single Channel Mask to a Color Representation

The function below converts a single-channel class mask to an RGB representation for visualization. Each class ID in the single-channel mask is mapped to a different color according to the class_index dictionary.

# Function to convert a single channel mask representation to an RGB mask.
def class_to_rgb(mask_class, class_index):
    
    # Create RGB channels 
    r_map = np.zeros_like(mask_class).astype(np.uint8)
    g_map = np.zeros_like(mask_class).astype(np.uint8)
    b_map = np.zeros_like(mask_class).astype(np.uint8)
    
    # Populate RGB color channels based on the color assigned to each class.
    for class_id in range(len(class_index)):
        index = mask_class == class_id
        r_map[index] = class_index[class_id][0][0]
        g_map[index] = class_index[class_id][0][1]
        b_map[index] = class_index[class_id][0][2]
        
    seg_map_rgb = np.stack([r_map, g_map, b_map], axis=2)
        
    return seg_map_rgb
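
As a side note, the same mapping can also be done with a single vectorized lookup: build a [32, 3] palette array once and index it with the class mask. A brief sketch that should produce the same result as class_to_rgb():

# Vectorized alternative: build the color palette once, then index it with the mask.
palette = np.array([class_index[class_id][0] for class_id in range(len(class_index))], dtype=np.uint8)
seg_map_rgb = palette[pred_mask_class]  # Shape: [H, W, 3].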

Convert the grayscale segmentation mask to a color segmentation mask and display the results.

pred_mask_rgb = class_to_rgb(pred_mask_class, class_index)  

plt.figure(figsize=(20,8))

plt.subplot(1,3,1)
plt.title('Input Image', fontsize=14)
plt.imshow(np.squeeze(images[0]))
plt.axis('off')

plt.subplot(1,3,2)
plt.title('Grayscale Segmentation', fontsize=14)
plt.imshow(pred_mask_class, cmap='gray') 
plt.axis('off')

plt.subplot(1,3,3)
plt.title('Color Segmentation', fontsize=14)
plt.imshow(pred_mask_rgb)
plt.axis('off');
image segmentation map examples

Formalize the Implementation

In this section, we will formalize the implementation by defining some additional convenience functions.

image_overlay()

image_overlay() is a helper function that overlays an RGB mask on top of the original image, making it easier to see how the predictions line up with the underlying scene.

# Function to overlay a segmentation map on top of an RGB image.
def image_overlay(image, seg_map_rgb):

    alpha = 1.0 # Transparency for the original image.
    beta  = 0.6 # Transparency for the segmentation map.
    gamma = 0.0 # Scalar added to each sum.

    # Convert the original image to uint8 in the range [0, 255].
    image = (image*255.0).astype(np.uint8)

    # Both the image and the mask are already in RGB, so they can be
    # blended directly without any color space conversions.
    image = cv2.addWeighted(image, alpha, seg_map_rgb, beta, gamma)

    return image
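
For reference, cv2.addWeighted computes a saturated per-pixel weighted sum, so the blending step above is equivalent to the following NumPy expression:

# NumPy equivalent of the cv2.addWeighted call (OpenCV saturates the result to [0, 255]).
blended = np.clip(alpha*image + beta*seg_map_rgb + gamma, 0, 255).astype(np.uint8)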

run_inference()

To perform inference on several images, we define the function below, which accepts a list of images and a pre-trained model. This function also handles all of the post-processing required to compute the final segmentation mask as well as the overlay.

def run_inference(images, model):
    
    for img in images:
        
        # Forward pass through the model (convert the tensor output to a numpy array).
        pred_mask = model.predict(img).numpy()
        
        # Remove the background class added by the model.
        pred_mask = pred_mask[:,:,:,1:]
        
        # Remove the batch dimension.
        pred_mask = np.squeeze(pred_mask)
        
        # `pred_mask` is a numpy array of shape [H, W, 32] where each channel contains the probability  
        # scores associated with a given class. We still need to assign a single class to each pixel 
        # which is accomplished using the argmax function across the last dimension to obtain the class labels.
        pred_mask_class = np.argmax(pred_mask, axis=-1)

        # Convert the predicted (class) segmentation map to a color segmentation map.
        pred_mask_rgb = class_to_rgb(pred_mask_class, class_index)
                
        fig = plt.figure(figsize=(20, 15))
        
        # Display the original image.
        ax1 = fig.add_subplot(1,3,1)
        ax1.imshow(img[0])
        ax1.title.set_text('Input Image')
        plt.axis('off')

        # Display the predicted color segmentation mask. 
        ax2 = fig.add_subplot(1,3,2)
        ax2.set_title('Predicted Mask')
        ax2.imshow(pred_mask_rgb)
        plt.axis('off')

        # Display the predicted color segmentation mask overlaid on the original image.
        overlaid_image = image_overlay(img[0], pred_mask_rgb)
        ax3 = fig.add_subplot(1,3,3)
        ax3.set_title('Overlaid Image')
        ax3.imshow(overlaid_image)
        plt.axis('off')
        
        plt.show()

plot_color_legend()

The function plot_color_legend() creates a color legend for the CamVid dataset, which is helpful for confirming the class assignments by the model.

def plot_color_legend(class_index):
    
    # Extract colors and labels from class_index dictionary.
    color_array = np.array([[v[0][0], v[0][1], v[0][2]] for v in class_index.values()]).astype(np.uint8)
    class_labels = [val[1] for val in class_index.values()]    
   
    fig, ax = plt.subplots(nrows=2, ncols=16, figsize=(20, 3))
    plt.subplots_adjust(wspace = 0.5, hspace=0.01)
    
    # Display color legend.
    for i, axis in enumerate(ax.flat):

        axis.imshow(color_array[i][None, None, :])
        axis.set_title(class_labels[i], fontsize = 8)
        axis.axis('off')
plot_color_legend(class_index)
Camvid color legend

Make Predictions on the Sample Images

Now, let’s use this function to perform inference on the sample images using the model we loaded above.

run_inference(images, seg_model)
Camvid inference results

Conclusion

In this post, we covered how to use pre-trained image segmentation models available in TensorFlow Hub. TensorFlow Hub simplifies the process of reusing existing models by providing a central repository for sharing, discovering, and reusing pre-trained machine learning models. An essential part of working with these models is understanding how to interpret their output: segmentation models produce multi-channel masks of per-class probability scores, which require post-processing to generate the final segmentation maps.
