
Object Detection Made Easy with TensorFlow Hub: Step-by-Step Tutorial

In this post, we will learn how to perform object detection with TensorFlow Hub pre-trained models. TensorFlow Hub is a library and platform designed for sharing, discovering, and reusing pre-trained machine learning models. The primary goal of TensorFlow Hub is to simplify the process of reusing existing models, thereby promoting collaboration, reducing redundant work, and accelerating research and development in machine learning. Users can search for pre-trained models, called modules, that have been contributed by the community or provided by Google. These modules can be easily integrated into a user’s own machine learning projects with just a few lines of code.

Object detection is a subfield of computer vision that focuses on identifying and locating specific objects within digital images or videos. It involves not only classifying the objects present in an image but also determining their precise location and size, by placing bounding boxes (or other spatial encodings) around them. In this example, we will use EfficientDet D4 (EfficientDet/d4 on TensorFlow Hub), which belongs to the EfficientDet family of models. The pre-trained models from this family available on TensorFlow Hub were all trained on the COCO 2017 dataset. The models in the family, D0 to D7, vary in complexity and input image size. D0, the most compact model, accepts input images of 512×512 pixels and provides the fastest inference. At the other end of the spectrum, D7 requires an input size of 1536×1536 and takes considerably longer to perform inference. TensorFlow Hub hosts several other object detection models as well.
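To make the size differences concrete, the short sketch below compares how many input pixels each variant processes, relative to D0. The resolutions are the ones listed for each model later in this post. Pixel count is only a rough proxy for cost, since the larger variants also use wider and deeper networks, so treat the ratios as a rough guide rather than a measured speed-up.

```python
# Square input resolutions for each EfficientDet variant, as listed on TensorFlow Hub.
INPUT_SIZES = {
    'D0': 512, 'D1': 640, 'D2': 768, 'D3': 896,
    'D4': 1024, 'D5': 1280, 'D6': 1280, 'D7': 1536,
}

def pixels_relative_to_d0(variant):
    """Number of input pixels for `variant`, as a multiple of D0's 512x512 input."""
    size = INPUT_SIZES[variant]
    return (size * size) / (INPUT_SIZES['D0'] ** 2)

for variant, size in INPUT_SIZES.items():
    print(f'{variant}: {size}x{size} -> {pixels_relative_to_d0(variant):.2f}x the pixels of D0')
```

For the D4 model used in this post that is four times as many pixels as D0, and D7 processes nine times as many.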

import os

# Silence TensorFlow's C++ INFO and WARNING messages. This must be set
# before TensorFlow (imported by tensorflow_hub) is loaded to take effect.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import numpy as np
import cv2

import zipfile
import requests
import glob

import tensorflow_hub as hub

import matplotlib.pyplot as plt

import warnings
import logging

# Filter absl warnings
warnings.filterwarnings("ignore", module="absl")

# Capture all warnings in the logging system
logging.captureWarnings(True)

# Set the absl logger level to 'error' to suppress warnings
absl_logger = logging.getLogger("absl")
absl_logger.setLevel(logging.ERROR)

Download Sample Images

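def download_file(url, save_name):
    # Download the file, failing loudly on HTTP errors instead of saving an error page.
    response = requests.get(url, timeout=60)
    response.raise_for_status()

    # Write the downloaded bytes to disk; the with-block closes the file.
    with open(save_name, 'wb') as f:
        f.write(response.content)

def unzip(zip_file=None):
    try:
        with zipfile.ZipFile(zip_file) as z:
            z.extractall("./")
            print("Extracted all")
    except zipfile.BadZipFile:
        print("Invalid file")

download_file(
    'https://www.dropbox.com/s/h7l1lmhvga6miyo/object_detection_images.zip?dl=1',
    'object_detection_images.zip'
)

unzip(zip_file='object_detection_images.zip')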
Extracted all

Display Sample Images

image_paths = sorted(glob.glob('object_detection_images' + '/*.png'))

for path in image_paths:
    print(path)
object_detection_images/dog_bicycle_car.png
object_detection_images/elephants.png
object_detection_images/home_interior.png
object_detection_images/place_setting.png
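def load_image(path):

    image = cv2.imread(path)

    # cv2.imread() returns None instead of raising when the file cannot be read.
    if image is None:
        raise FileNotFoundError(f'Could not read image: {path}')

    # Convert image in BGR format to RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Add a batch dimension which is required by the model.
    image = np.expand_dims(image, axis=0)

    return image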
images = []
fig, ax = plt.subplots(nrows=2, ncols=2, figsize=(20, 15))

# Load each image, keep it for inference, and show it in its own panel.
for axis, path in zip(ax.flat, image_paths):
    image = load_image(path)
    images.append(image)
    axis.imshow(image[0])
    axis.axis('off')
Sample images to use for object detection with TensorFlow Hub.
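OpenCV loads images with their channels in BGR order, while Matplotlib and the model expect RGB. For a 3-channel image, `cv2.COLOR_BGR2RGB` reverses the channel axis, so the NumPy-only sketch below mirrors what `load_image()` does after reading the file, using a tiny synthetic image. The helper name `bgr_to_batched_rgb` is hypothetical; the tutorial itself uses `cv2.cvtColor` and `np.expand_dims`.

```python
import numpy as np

def bgr_to_batched_rgb(image_bgr):
    """Hypothetical helper: reorder BGR -> RGB and add a leading batch axis."""
    image_rgb = image_bgr[..., ::-1]          # reverse the channel axis: B,G,R -> R,G,B
    return np.expand_dims(image_rgb, axis=0)  # shape (H, W, 3) -> (1, H, W, 3)

# A 1x2 "image" stored in BGR order: one pure-blue pixel, one pure-red pixel.
bgr = np.array([[[255, 0, 0], [0, 0, 255]]], dtype=np.uint8)
batch = bgr_to_batched_rgb(bgr)

print(batch.shape)               # (1, 1, 2, 3)
print(batch[0, 0, 0].tolist())   # [0, 0, 255]: blue now sits in the last (B) slot
```

One practical difference: the reversed slice is a view with negative strides, which some OpenCV drawing functions may reject; `np.ascontiguousarray` (or `.copy()`, as `process_detection()` later does) produces a regular array.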

Define a Dictionary that Maps Class IDs to Class Names

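class_index is a dictionary that maps class IDs to class names for the 80 object categories in the COCO dataset. The IDs range from 1 to 90, with gaps where COCO leaves an ID unused (for example, 12 and 26).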

class_index =  \
{
         1: 'person',
         2: 'bicycle',
         3: 'car',
         4: 'motorcycle',
         5: 'airplane',
         6: 'bus',
         7: 'train',
         8: 'truck',
         9: 'boat',
         10: 'traffic light',
         11: 'fire hydrant',
         13: 'stop sign',
         14: 'parking meter',
         15: 'bench',
         16: 'bird',
         17: 'cat',
         18: 'dog',
         19: 'horse',
         20: 'sheep',
         21: 'cow',
         22: 'elephant',
         23: 'bear',
         24: 'zebra',
         25: 'giraffe',
         27: 'backpack',
         28: 'umbrella',
         31: 'handbag',
         32: 'tie',
         33: 'suitcase',
         34: 'frisbee',
         35: 'skis',
         36: 'snowboard',
         37: 'sports ball',
         38: 'kite',
         39: 'baseball bat',
         40: 'baseball glove',
         41: 'skateboard',
         42: 'surfboard',
         43: 'tennis racket',
         44: 'bottle',
         46: 'wine glass',
         47: 'cup',
         48: 'fork',
         49: 'knife',
         50: 'spoon',
         51: 'bowl',
         52: 'banana',
         53: 'apple',
         54: 'sandwich',
         55: 'orange',
         56: 'broccoli',
         57: 'carrot',
         58: 'hot dog',
         59: 'pizza',
         60: 'donut',
         61: 'cake',
         62: 'chair',
         63: 'couch',
         64: 'potted plant',
         65: 'bed',
         67: 'dining table',
         70: 'toilet',
         72: 'tv',
         73: 'laptop',
         74: 'mouse',
         75: 'remote',
         76: 'keyboard',
         77: 'cell phone',
         78: 'microwave',
         79: 'oven',
         80: 'toaster',
         81: 'sink',
         82: 'refrigerator',
         84: 'book',
         85: 'clock',
         86: 'vase',
         87: 'scissors',
         88: 'teddy bear',
         89: 'hair drier',
         90: 'toothbrush'
}

Next, we build COLOR_IDS, an array of distinct RGB colors, and index it by class ID so that each class is drawn in its own color.

R = np.array(np.arange(96, 256, 32))
G = np.roll(R, 1)
B = np.roll(R, 2)

COLOR_IDS = np.array(np.meshgrid(R, G, B)).T.reshape(-1, 3)
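As a quick sanity check, the sketch below rebuilds `COLOR_IDS` with the same three lines and confirms that it holds 125 distinct colors (5 intensity levels per channel, 5 x 5 x 5 combinations). Because every COCO class ID is at most 90, `class_id % len(COLOR_IDS)` in the drawing code is simply `class_id`, so no two classes share a color.

```python
import numpy as np

# Same construction as above: 5 intensity levels per channel.
R = np.array(np.arange(96, 256, 32))  # [96, 128, 160, 192, 224]
G = np.roll(R, 1)
B = np.roll(R, 2)
COLOR_IDS = np.array(np.meshgrid(R, G, B)).T.reshape(-1, 3)

num_colors = len(COLOR_IDS)
num_unique = len(np.unique(COLOR_IDS, axis=0))
print(num_colors, num_unique)  # 125 125

# Every COCO class ID (1..90) lands on its own row, since all IDs are below 125.
color_rows = {class_id % num_colors for class_id in range(1, 91)}
print(len(color_rows))  # 90
```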


Model Inference using TensorFlow Hub

TensorFlow Hub contains many different pre-trained object detection models. Here we will use the EfficientDet family of object detectors, which were trained on the COCO 2017 dataset. The family consists of eight models, D0 to D7, which differ mainly in their architecture, input image size, computational requirements, and accuracy: the larger models are generally more accurate but slower.

EfficientDet  = {'EfficientDet D0 512x512'   : 'https://tfhub.dev/tensorflow/efficientdet/d0/1',
                 'EfficientDet D1 640x640'   : 'https://tfhub.dev/tensorflow/efficientdet/d1/1',
                 'EfficientDet D2 768x768'   : 'https://tfhub.dev/tensorflow/efficientdet/d2/1',
                 'EfficientDet D3 896x896'   : 'https://tfhub.dev/tensorflow/efficientdet/d3/1',
                 'EfficientDet D4 1024x1024' : 'https://tfhub.dev/tensorflow/efficientdet/d4/1',
                 'EfficientDet D5 1280x1280' : 'https://tfhub.dev/tensorflow/efficientdet/d5/1',
                 'EfficientDet D6 1280x1280' : 'https://tfhub.dev/tensorflow/efficientdet/d6/1',
                 'EfficientDet D7 1536x1536' : 'https://tfhub.dev/tensorflow/efficientdet/d7/1'
                }

Here we will use the D4 model.

model_url = EfficientDet['EfficientDet D4 1024x1024']

print('loading model: ', model_url)
od_model = hub.load(model_url)

print('\nmodel loaded!')
loading model:  https://tfhub.dev/tensorflow/efficientdet/d4/1
Metal device set to: Apple M1 Max

model loaded!

Perform Inference

Before we formalize the code to process several images and post-process the results, let’s first see how to perform inference on a single image and study the output from the model.

Call the Model

# Call the model.
# The model returns the detection results in the form of a dictionary.
results = od_model(images[0])

Inspect the Results

The object detection model returns the detection results as a dictionary with several keys.

# Convert the dictionary values to numpy arrays.
results = {key:value.numpy() for key, value in results.items()}
# Print the keys from the results dictionary.
for key in results:
    print(key) 
detection_anchor_indices
detection_boxes
detection_classes
detection_multiclass_scores
detection_scores
num_detections
raw_detection_boxes
raw_detection_scores

Notice that the results dictionary provides several keys for accessing different types of detection data. EfficientDet, like many other object detection models, generates a large number of raw detections (bounding boxes and corresponding class scores) for each input image. Many of these raw detections are redundant, overlapping, or have low confidence scores. To obtain meaningful results, the model applies post-processing internally to filter and refine these raw detections. For our purposes, we are only interested in the post-processed detections, which are available under the keys that start with detection_.
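The post does not spell out which post-processing the model applies, but the usual technique is non-maximum suppression (NMS): keep the highest-scoring box, discard any remaining box that overlaps it by more than an IoU (intersection-over-union) threshold, and repeat. The sketch below is a minimal, class-agnostic greedy NMS written for illustration; it is not the model's internal implementation, and the 0.5 IoU threshold is an assumed value. Boxes use the normalized `[ymin, xmin, ymax, xmax]` order of `detection_boxes`.

```python
import numpy as np

def iou(box_a, box_b):
    """Intersection-over-union of two boxes in [ymin, xmin, ymax, xmax] order."""
    ymin = max(box_a[0], box_b[0])
    xmin = max(box_a[1], box_b[1])
    ymax = min(box_a[2], box_b[2])
    xmax = min(box_a[3], box_b[3])
    inter = max(0.0, ymax - ymin) * max(0.0, xmax - xmin)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)

def greedy_nms(boxes, scores, iou_thresh=0.5):
    """Class-agnostic greedy NMS; returns indices of kept boxes, best first."""
    order = np.argsort(scores)[::-1]  # highest score first
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        overlaps = np.array([iou(boxes[best], boxes[i]) for i in rest])
        order = rest[overlaps < iou_thresh]  # drop boxes that overlap `best` too much
    return keep

# Three boxes and scores copied from the first image's detections (bicycle, truck, car).
boxes = np.array([
    [0.16487242, 0.15703079, 0.7441227,  0.74429274],  # bicycle
    [0.06442685, 0.61166453, 0.25209486, 0.8956611 ],  # truck
    [0.06630661, 0.611912,   0.25146762, 0.89877594],  # car
])
scores = np.array([0.9053347, 0.7202968, 0.35475922])

print(round(iou(boxes[1], boxes[2]), 3))  # 0.975
print(greedy_nms(boxes, scores))          # [0, 1]
```

The truck and car boxes overlap with an IoU of roughly 0.975, yet the model reports both. That suggests the model's own suppression is applied per class; a class-agnostic pass like this sketch keeps only the higher-scoring truck.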

In the following code cell, we show that there are nearly 200,000 raw detections but only 16 final detections. Each final detection has an associated confidence score, which we may want to filter further depending on the nature of our application.

print('Num Raw Detections: ', (len(results['raw_detection_scores'][0])))
print('Num Detections:     ', (results['num_detections'][0]).astype(int))
Num Raw Detections:  196416
Num Detections:      16
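The raw detection count is not arbitrary: it equals the number of anchor boxes the model scores. Assuming the standard EfficientDet anchor configuration (feature levels P3 to P7, i.e. strides 8 to 128, with 3 scales x 3 aspect ratios = 9 anchors per feature-map location), the arithmetic below reproduces the 196,416 raw detections reported for the 1024x1024 D4 model. These configuration values are assumptions taken from the EfficientDet design, not something this post states.

```python
def count_anchors(input_size, min_level=3, max_level=7, anchors_per_location=9):
    """Total anchors over feature levels min_level..max_level for a square input.

    Assumes level L has stride 2**L and that every feature-map location
    carries `anchors_per_location` anchors (3 scales x 3 aspect ratios).
    """
    total = 0
    for level in range(min_level, max_level + 1):
        stride = 2 ** level
        locations = (input_size // stride) ** 2  # feature-map cells at this level
        total += locations * anchors_per_location
    return total

print(count_anchors(1024))  # D4 input size -> 196416
print(count_anchors(512))   # D0 input size -> 49104
```

For D4 the five levels contribute 128x128, 64x64, 32x32, 16x16 and 8x8 locations (21,824 in total), and 21,824 x 9 = 196,416, which matches the raw count printed above.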

Let’s now inspect some of the detection data for all 16 detections. Notice that the detections are sorted from the highest confidence detections to the lowest.

# Print the Scores, Classes and Bounding Boxes for the detections.
num_dets = (results['num_detections'][0]).astype(int)

print('\nDetection Scores: \n\n', results['detection_scores'][0][0:num_dets])
print('\nDetection Classes: \n\n', results['detection_classes'][0][0:num_dets])
print('\nDetection Boxes: \n\n', results['detection_boxes'][0][0:num_dets])
Detection Scores: 

 [0.9053347  0.8789406  0.7202968  0.35475922 0.2805733  0.17851698
 0.15169667 0.14905979 0.14454156 0.13584    0.12682638 0.11745102
 0.10781792 0.10152479 0.10052315 0.09746186]

Detection Classes: 

 [ 2. 18.  8.  3. 64. 64.  2. 18. 64. 64. 64.  4. 64. 44. 64. 77.]

Detection Boxes: 

 [[0.16487242 0.15703079 0.7441227  0.74429274]
 [0.3536     0.16668764 0.9776781  0.40675405]
 [0.06442685 0.61166453 0.25209486 0.8956611 ]
 [0.06630661 0.611912   0.25146762 0.89877594]
 [0.08410528 0.06995308 0.18153256 0.13178551]
 [0.13754636 0.89751065 0.22187063 0.9401711 ]
 [0.34510636 0.16857824 0.97165954 0.40917954]
 [0.18023838 0.15531728 0.7696747  0.7740346 ]
 [0.087889   0.06875686 0.18782085 0.10366233]
 [0.00896974 0.11013152 0.0894229  0.15709913]
 [0.08782443 0.08899567 0.16129945 0.13988526]
 [0.16456181 0.1708141  0.72982967 0.75529355]
 [0.06907014 0.8944937  0.22174956 0.9605442 ]
 [0.30221778 0.10927744 0.33091408 0.15160759]
 [0.11132257 0.09432659 0.16303536 0.12937708]
 [0.133767   0.5592607  0.18178582 0.5844183 ]]
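The class IDs are easier to read once they are mapped to names. The sketch below applies `class_index` to the first five detections printed above; it uses a small subset of the dictionary so that it runs on its own, and the helper `describe` is hypothetical. Using `.get()` with a fallback is a defensive choice, since the COCO ID range has unused gaps.

```python
# Scores and class IDs for the first five detections, copied from the output above.
scores  = [0.9053347, 0.8789406, 0.7202968, 0.35475922, 0.2805733]
classes = [2.0, 18.0, 8.0, 3.0, 64.0]   # the model returns class IDs as floats

# A subset of class_index that covers these IDs.
class_index = {2: 'bicycle', 3: 'car', 8: 'truck', 18: 'dog', 64: 'potted plant'}

def describe(scores, classes, class_index):
    """Turn parallel score/class arrays into 'name: score%' strings."""
    lines = []
    for score, class_id in zip(scores, classes):
        class_id = int(class_id)
        # .get() avoids a KeyError for IDs that are missing from the dictionary.
        name = class_index.get(class_id, f'unknown id {class_id}')
        lines.append(f'{name}: {100 * score:.1f}%')
    return lines

for line in describe(scores, classes, class_index):
    print(line)
```

Read this way, the three most confident detections are a bicycle, a dog and a truck, followed by a low-confidence car and potted plant.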

Post-Process and Display Detections

Here we show how to interpret the detection data for a single image. As shown above, the model returned 16 detections; however, many of them have low confidence scores, so we need to filter them further using a minimum detection threshold. The processing steps are:

  1. Retrieve the detections from the results dictionary
  2. Apply a minimum detection threshold to filter the detections
  3. For each thresholded detection, display the bounding box and a label indicating the detected class and the confidence of the detection.
def process_detection(image, results,  min_det_thresh=.3):

    # Extract the detection results from the results dictionary.
    scores  =  results['detection_scores'][0]
    boxes   =  results['detection_boxes'][0]
    classes = (results['detection_classes'][0]).astype(int)

    # Keep only the detections whose scores meet or exceed min_det_thresh.
    det_indices = np.where(scores >= min_det_thresh)[0]

    scores_thresh  = scores[det_indices]
    boxes_thresh   = boxes[det_indices]
    classes_thresh = classes[det_indices]

    # Make a copy of the image to annotate.
    img_bbox = image.copy()

    im_height, im_width = image.shape[:2]

    font_scale = .6
    box_thickness = 2

    # Loop over all thresholded detections.
    for box, class_id, score in zip(boxes_thresh, classes_thresh, scores_thresh):

        # Get the normalized bounding box coordinates ([ymin, xmin, ymax, xmax] order).
        ymin, xmin, ymax, xmax = box

        class_name = class_index[class_id]

        # Convert normalized bounding box coordinates to pixel coordinates.
        (left, right, top, bottom) = (int(xmin * im_width), 
                                      int(xmax * im_width), 
                                      int(ymin * im_height), 
                                      int(ymax * im_height))

        # Annotate the image with the bounding box.
        color = tuple(COLOR_IDS[class_id % len(COLOR_IDS)].tolist())[::-1]
        img_bbox = cv2.rectangle(img_bbox, (left, top), (right, bottom), color, thickness=box_thickness)

        #-------------------------------------------------------------------
        # Annotate bounding box with detection data (class name and score).
        #-------------------------------------------------------------------

        # Build the text string that contains the class name and score associated with this detection.
        display_txt = '{}: {:.2f}%'.format(class_name, 100 * score)
        ((text_width, text_height), _) = cv2.getTextSize(display_txt, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 1)
        
        # Handle case when the label is above the image frame.
        if top < text_height:
            shift_down = int(2*(1.3*text_height))
        else:
            shift_down = 0
        
        # Draw a filled rectangle on which the detection results will be displayed.
        img_bbox = cv2.rectangle(img_bbox, 
                                 (left-1, top-box_thickness - int(1.3*text_height) + shift_down), 
                                 (left-1 + int(1.1 * text_width), top),               
                                 color, 
                                 thickness=-1)

        # Annotate the filled rectangle with text (class label and score).
        img_bbox = cv2.putText(img_bbox, 
                               display_txt,
                               (left + int(.05*text_width), top - int(0.2*text_height) + int(shift_down/2)),
                               cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), 1)
    return img_bbox
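The coordinate conversion inside `process_detection()` is easy to get wrong because `detection_boxes` stores `[ymin, xmin, ymax, xmax]` normalized to [0, 1], while OpenCV drawing functions take `(x, y)` pixel points. The sketch below isolates that step as a standalone function; the name `to_pixel_corners` and the 800x400 image size in the example are hypothetical.

```python
def to_pixel_corners(box, im_height, im_width):
    """Convert a normalized [ymin, xmin, ymax, xmax] box to pixel corner points.

    Returns ((left, top), (right, bottom)), the two points cv2.rectangle() expects.
    """
    ymin, xmin, ymax, xmax = box
    left, right = int(xmin * im_width), int(xmax * im_width)
    top, bottom = int(ymin * im_height), int(ymax * im_height)
    return (left, top), (right, bottom)

# Hypothetical 800x400 (width x height) image and an easy-to-check box.
print(to_pixel_corners([0.25, 0.1, 0.75, 0.5], im_height=400, im_width=800))
# ((80, 100), (400, 300))
```

Note that x comes from the second and fourth box entries and is scaled by the width, while y comes from the first and third and is scaled by the height.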

Display Results with min_det_thresh=0

First, let’s process an image using a minimum detection threshold of zero just to see what the model returned for all 16 detections. Since we are not filtering the results, we expect that we may have some redundant and/or false detections.

# Call the model.
results = od_model(images[0])

# Convert the dictionary values to numpy arrays.
results = {key:value.numpy() for key, value in results.items()}

# Remove the batch dimension from the first image.
image = np.squeeze(images[0])

# Process the first sample image.
img_bbox = process_detection(image, results, min_det_thresh=0)

plt.figure(figsize=[15, 10])
plt.imshow(img_bbox)
plt.axis('off');

The result below shows all the detections returned by the model, since we did not apply a detection threshold to filter them. Notice, however, that all the mislabeled detections also have very low confidence. It is therefore always recommended to apply a minimum detection threshold to the results generated by the model. The right threshold depends on the data and the application and requires some experimentation, but a value between 0.3 and 0.5 is a good rule of thumb.

EfficientDet results with zero threshold.
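To see how sensitive the result is to this choice, the sketch below counts how many of the 16 detections from the first image survive different thresholds, using the scores printed earlier. The helper `count_kept` is hypothetical and mirrors the `scores >= min_det_thresh` test in `process_detection()`.

```python
# Detection scores for the first image, copied from the output printed earlier.
scores = [0.9053347, 0.8789406, 0.7202968, 0.35475922, 0.2805733, 0.17851698,
          0.15169667, 0.14905979, 0.14454156, 0.13584, 0.12682638, 0.11745102,
          0.10781792, 0.10152479, 0.10052315, 0.09746186]

def count_kept(scores, min_det_thresh):
    """Number of detections whose score meets or exceeds the threshold."""
    return sum(1 for score in scores if score >= min_det_thresh)

for thresh in (0.0, 0.3, 0.5):
    print(f'min_det_thresh={thresh}: {count_kept(scores, thresh)} of {len(scores)} kept')
```

Moving from 0.3 to 0.5 only removes the "car" detection with a score of about 0.35; everything else in the 0.3 to 0.5 rule-of-thumb range is unaffected for this image.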

Display Results with min_det_thresh=0.3

Let’s now apply a detection threshold to filter the results.

img_bbox = process_detection(image, results, min_det_thresh=.3)

plt.figure(figsize=[15, 10])
plt.imshow(img_bbox)
plt.axis('off');

Formalize the Implementation

In this section, we formalize the implementation and create a convenience function to run the model on a list of images. As noted in the model documentation, the models in this family do not support batching, so we need to call the model once for each image. Note, however, that each input image must still include a batch dimension (of size 1).
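Because the model accepts exactly one image per call, it can help to validate inputs before calling it. The sketch below is a hypothetical helper, `to_model_input`, that adds the batch dimension when it is missing and rejects anything that is not a single uint8 RGB image. The `[1, height, width, 3]` uint8 layout reflects the input description on the EfficientDet TF Hub model pages; check the page of the model you actually use.

```python
import numpy as np

def to_model_input(image):
    """Hypothetical helper: return a single uint8 RGB image shaped [1, height, width, 3]."""
    if image.dtype != np.uint8:
        raise TypeError(f'expected uint8 pixels, got {image.dtype}')
    if image.ndim == 3:
        image = np.expand_dims(image, axis=0)  # add the required batch dimension
    if image.ndim != 4 or image.shape[0] != 1 or image.shape[-1] != 3:
        raise ValueError(f'expected shape [1, height, width, 3], got {image.shape}')
    return image

single = np.zeros((480, 640, 3), dtype=np.uint8)
print(to_model_input(single).shape)  # (1, 480, 640, 3)
```

Passing a stack of two images, shape (2, 480, 640, 3), raises a ValueError before it reaches the model, which matches the one-image-per-call restriction.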

run_inference()

run_inference() is a helper function that will call the model for each image in the list of images.

def run_inference(images, model):
    
    results_list = []
    for img in images:
        result = model(img)
        result = {key:value.numpy() for key,value in result.items()}

        results_list.append(result)

    return results_list
# Perform inference on each image and store the results in a list.
results_list = run_inference(images, od_model)
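When comparing EfficientDet variants for speed, a simple first step is to time each call. The sketch below is a hypothetical variant of `run_inference()` that also records wall-clock seconds per image with `time.perf_counter()`. Treat the first timing with caution, since the first call to a freshly loaded model often includes one-time setup cost.

```python
import time

def run_inference_timed(images, model):
    """Hypothetical variant of run_inference() that also returns per-call latency in seconds."""
    results_list, timings = [], []
    for img in images:
        start = time.perf_counter()
        result = model(img)
        # Converting to NumPy forces the result to be materialized, so keep it inside the timed region.
        result = {key: value.numpy() for key, value in result.items()}
        timings.append(time.perf_counter() - start)
        results_list.append(result)
    return results_list, timings
```

With the real model, `results_list, timings = run_inference_timed(images, od_model)` followed by `print(timings[1:])` (skipping the warm-up call) gives a rough per-image latency that you can compare across the D0 to D7 variants.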

Next, we loop over the images and use the model results to annotate a copy of each image, which is then displayed.

# Create a single tall figure with one row per image.
plt.figure(figsize=[20, 10*len(images)])

for idx in range(len(images)):

    # Remove the batch dimension.
    image = np.squeeze(images[idx])

    # Generate the annotated image.
    image_bbox = process_detection(image, results_list[idx], min_det_thresh=.31)

    # Display the annotated image in its own row of the figure.
    plt.subplot(len(images), 1, idx+1)
    plt.imshow(image_bbox)
    plt.axis('off')
EfficientDet results dog, bicycle, car
EfficientDet results elephants
EfficientDet results home interior
EfficientDet results place setting

Conclusion

In this post, we covered how to use pre-trained object detection models available in TensorFlow Hub. TensorFlow Hub simplifies reusing existing models by providing a central repository for sharing, discovering, and reusing pre-trained machine learning models. An essential part of working with these models is interpreting their output, and in particular applying a detection threshold to filter the results the model generates. Setting an appropriate detection threshold often requires experimentation and depends heavily on the type of application. In this example, we used the D4 model from the EfficientDet family. If your application requires faster inference, consider one of the smaller models (D0 to D3).

Disclaimer

All views expressed on this site are my own and do not represent the opinions of OpenCV.org or any entity whatsoever with which I have been, am now, or will be affiliated.

About LearnOpenCV

In 2007, right after finishing my Ph.D., I co-founded TAAZ Inc. with my advisor Dr. David Kriegman and Kevin Barnes. The scalability, and robustness of our computer vision and machine learning algorithms have been put to rigorous test by more than 100M users who have tried our products.

Copyright © 2023 – BIG VISION LLC