According to World Wildlife Fund assessments, the global biodiversity crisis has reached critical levels, with monitored wildlife populations declining by an average of 69% since 1970. From Africa’s savannahs to Asia’s forests, animal populations have been hit hard by habitat loss, climate change, and poaching. Among the most vulnerable species are African elephants (population down 60% in a decade), black rhinos (around 6,487 individuals remaining), Cape buffaloes (about 490,000 left), and plains zebras (150,000–200,000 surviving). The illegal wildlife trade generates an estimated $23 billion annually, while human-wildlife conflicts cost African nations $142 million yearly in crop damage and livestock losses. But how can deep learning help here? Well, fine-tuning RetinaNet is one answer.

To elaborate on the “how”: we can fine-tune a deep learning model to detect and monitor wild animals effectively, even at night or in remote, restricted areas where humans cannot always be present. This approach can help protect these animals and contribute to biodiversity preservation and wildlife conservation.
In this article, we will fine-tune RetinaNet on wildlife animal data. We will use PyTorch to build our training pipeline. Throughout the tutorial, we will explore:
- What is Wildlife Animal Detection? Why Do We Need It?
- Why did we choose RetinaNet?
- What are the Challenges of Solving This Problem?
- Building the Training Pipeline using PyTorch
- Fine-Tuning the RetinaNet Model
- Inference and Comparison with YOLO11
- A Quick Recap of the Article
Let’s get started!
In this tutorial, you’ll learn how to fine-tune RetinaNet using PyTorch for accurate wildlife animal detection, reaching a mean Average Precision (mAP) of 79% on a challenging wildlife dataset. RetinaNet’s architecture, featuring Focal Loss and a Feature Pyramid Network, effectively handles class imbalance and small or obscured animals, which are common hurdles in wildlife monitoring. Additionally, we’ll benchmark RetinaNet’s performance against YOLO11, highlighting key accuracy differences and sharing practical tips to maximize model effectiveness even with limited GPU resources.
Also, we have given the entire code in the Download Code section; try it on your own!
- Introduction to Wildlife Animal Detection: Why Do We Need It?
- Why We Chose RetinaNet for This Task
- Challenges in Fine-Tuning a Wildlife Dataset with RetinaNet
- Preparing the Wildlife Dataset for RetinaNet
- Code Pipeline – Fine-Tuning RetinaNet
- Inference – Fine-Tuned RetinaNet vs Fine-Tuned YOLO11
- Quick Recap
- Conclusion
- References
Introduction to Wildlife Animal Detection: Why Do We Need It?
Before moving to the coding part, let’s understand the current scenario. Traditionally, wildlife monitoring has relied on field surveys, camera traps, and aerial surveys. These methods are often slow, costly, and incomplete. Object detection models can help track species, prevent poaching, and predict habitat loss without requiring constant human supervision.
Key issues and challenges in this field include:
Slow Response to Poaching:
Without specialized equipment or continuous surveillance, rangers often discover poachers’ tracks hours—or even days—after a killing takes place. This delay increases the risk to both animals and rangers, since it’s harder to catch perpetrators or provide timely medical care to injured wildlife.
Inaccurate Population Monitoring:
Obtaining precise counts of buffaloes, elephants, or other species is time-consuming and prone to human error. By the time researchers process the data, population shifts or health issues may have already occurred, leaving critical conservation opportunities unmet.
High Costs and Operational Risks:
Traditional monitoring methods require large teams, heavy equipment, and extensive travel to remote locations. These efforts are expensive and fraught with dangers like extreme weather, rugged terrain, and encounters with venomous creatures in dense forests.
Here, deep learning comes into play. Fine-tuning RetinaNet on wildlife animal data allows us to spot poachers or other threats faster, stopping the harm before it starts. Whether scanning a small reserve or an entire continent, our fine-tuned model can monitor more animals in less time and at a lower cost. This capability helps us observe animals more effectively, provide them better care and protection, and ultimately boost biodiversity.
Why We Chose RetinaNet for This Task
RetinaNet stands out by combining a ResNet backbone with a Feature Pyramid Network (FPN) to produce multi-scale feature maps, allowing it to catch animals of various sizes in complex environments. It splits detection into two sub-networks: one for classification and one for bounding box regression. This design works hand in hand with Focal Loss, which down-weights easy negatives and spotlights hard positives.
Now, you might wonder, “Isn’t YOLO famous for speed? Why not just use YOLO?” Indeed, YOLO is fast and popular for real-time detection, but RetinaNet remains a strong contender, especially when dealing with:
Class Imbalance: Wildlife datasets often have many images of common backgrounds (like grass or sky) and fewer images of rare species (like rhinos). RetinaNet’s Focal Loss is designed to tackle this imbalance by down-weighting easy negatives (background) and focusing on harder positives (actual animals).
Small or Partially Obscured Animals: In dense forests or tall grass, animals may appear as tiny shapes or partial outlines. RetinaNet’s multi-scale feature pyramid can better handle objects of various sizes, giving it an edge in complex scenes.
Consistency in Real-World Tests: Many field projects choose RetinaNet because it consistently balances accuracy and speed. YOLO might be faster in simpler environments, but RetinaNet often delivers higher precision on challenging wildlife datasets.
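To make the Focal Loss point above concrete, here is a minimal sketch of the idea. Torchvision ships this as torchvision.ops.sigmoid_focal_loss; the snippet below only illustrates the math and is not the exact code used inside RetinaNet:
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    # logits, targets: tensors of shape (num_anchors, num_classes); targets are one-hot floats.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # probability assigned to the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t)^gamma shrinks the loss of easy, confident examples (mostly background),
    # so the rare, hard animal examples dominate the gradient.
    return (alpha_t * (1 - p_t) ** gamma * bce).sum()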
After fine-tuning, we will run inference with our fine-tuned RetinaNet and compare it with a fine-tuned YOLO11. This will give us a good idea of how our model performs on new data.
Challenges in Fine-Tuning a Wildlife Dataset with RetinaNet
Even though RetinaNet is suitable for wildlife tasks, we face two significant problems when training on limited wildlife data:
Availability of Data
Scarce Wildlife Images: Endangered species like rhinos or rare buffalo variants often have very few labeled examples. This shortage makes it risky to train RetinaNet entirely from scratch.
Relying on Pretrained Models: Since we have a limited amount of data, we use the pretrained RetinaNet model from torchvision. While this approach helps in our case, it can lead to a domain mismatch if our field images differ significantly from the original COCO dataset.
Hardware Constraints
Cheaper and Accessible GPU Power: We use a laptop GPU (RTX 3070 Ti) instead of expensive options like A100 or H100 to keep costs low and make AI more accessible. You can download the code and play with it on your laptop as well!
Careful Optimization: We focus on model compression and efficient data handling to avoid slow training and memory issues, so the model runs smoothly even on limited hardware such as edge devices.
Enough theory? Let’s dive into the code!
Preparing the Wildlife Dataset for RetinaNet
We are using the African Wildlife Dataset available in Ultralytics. This dataset showcases four common animal classes typically found in South African nature reserves. It includes images of African wildlife such as buffalo, elephant, rhino, and zebra.
The Dataset Structure:
.
├── test
│ ├── images
│ └── labels
├── train
│ ├── images
│ └── labels
└── valid
├── images
└── labels
For each image, we have a corresponding label (.txt) file. We have split the data into three main folders:
Training set (1052 images): This subset is the largest, comprising 70% of the total dataset. It includes a wide range of images for each class, ensuring comprehensive coverage of the variations and scenarios in which the objects might be encountered. This extensive training set is crucial for the model to learn the distinctive features of each class.
Validation set (225 images): Making up 15% of the dataset, the validation set is used during the model training process to evaluate its performance on data it hasn’t seen during training. This helps in tuning the model’s parameters and in detecting overfitting early.
Test set (227 images): Also constituting 15% of the dataset, the test set is utilized post-training to assess the model’s generalization capability on completely unseen data. This is critical for understanding how the model will perform in real-world scenarios.
The labels come in YOLO format:
0 0.819167 0.593750 0.148333 0.242500
0 0.747500 0.472500 0.221667 0.190000
class_id x_center y_center width height
But the pretrained RetinaNet in PyTorch (like other torchvision detection models) expects corner-format boxes:
0 0.133333 0.285578 0.640104 0.938872
class_id x_min y_min x_max y_max
So, we have converted the data to the required format. You don’t need to do that again manually; we have provided the processed dataset and the data processing script behind the Download Code button below.
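For reference, the conversion boils down to a few lines like the following. This is only an illustrative sketch (the shipped preprocessing script may differ); it keeps the coordinates normalized, and they are scaled to pixels later inside the dataset class:
import glob

def yolo_to_corners(line):
    cls_id, xc, yc, w, h = line.split()
    xc, yc, w, h = map(float, (xc, yc, w, h))
    # Convert center format to corner format; values stay normalized to [0, 1].
    return f"{cls_id} {xc - w/2:.6f} {yc - h/2:.6f} {xc + w/2:.6f} {yc + h/2:.6f}"

for label_path in glob.glob("data/*/labels/*.txt"):   # assumes the folder layout shown above
    with open(label_path) as f:
        converted = [yolo_to_corners(line) for line in f if line.strip()]
    with open(label_path, "w") as f:
        f.write("\n".join(converted) + "\n")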
Code Pipeline – Fine-Tuning RetinaNet
We will now start building the training pipeline for fine-tuning RetinaNet. We organized the code for better readability. The structure of our pipeline is as follows:
.
├── app.py
├── config.py
├── custom_utils.py
├── datasets.py
├── model.py
├── train.py
└── inference.py
Let’s dive into more detail and see how each script works.
config.py
import torch
BATCH_SIZE = 8 # Increase / decrease according to GPU memory.
RESIZE_TO = 640 # Resize the image for training and transforms.
NUM_EPOCHS = 60 # Number of epochs to train for.
NUM_WORKERS = 4 # Number of parallel workers for data loading.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Training images and labels files directory.
TRAIN_DIR = "data/train"
# Validation images and labels files directory.
VALID_DIR = "data/valid"
# Classes: 0 index is reserved for background.
CLASSES = ["__background__", "buffalo", "elephant", "rhino", "zebra"]
NUM_CLASSES = len(CLASSES)
# Whether to visualize images after creating the data loaders.
VISUALIZE_TRANSFORMED_IMAGES = True
# Location to save model and plots.
OUT_DIR = "outputs"
In this script, we define the important settings that control our training. We set the batch size, the number of epochs, and the number of worker threads for reading data. RESIZE_TO decides how large to resize each image. We check if a GPU ("cuda") is available and, if so, set the training device to the GPU; otherwise, we use the CPU. The file paths for our training and validation data folders are also here.
We include a list of classes, with the first entry "__background__" for the background class and the others for the animals: buffalo, elephant, rhino, and zebra. NUM_CLASSES is the length of this list, and that is how our model knows how many classes it should predict. The boolean VISUALIZE_TRANSFORMED_IMAGES can be turned on to see how the data augmentation looks. Finally, we set OUT_DIR as the place where all outputs, such as model checkpoints and plots, will be saved.
custom_utils.py
The custom_utils.py script provides helper functions that keep our training code more organized.
Averager and SaveBestModel
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2


class Averager:
    """
    A class to keep track of a running average of values (e.g., training loss).
    """
    def __init__(self):
        self.current_total = 0.0
        self.iterations = 0.0

    def send(self, value):
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        if self.iterations == 0:
            return 0
        else:
            return self.current_total / self.iterations

    def reset(self):
        self.current_total = 0.0
        self.iterations = 0.0


class SaveBestModel:
    """
    Saves the model if the current epoch's validation mAP is higher
    than all previously observed values.
    """
    def __init__(self, best_valid_map=float(0)):
        self.best_valid_map = best_valid_map

    def __call__(self, model, current_valid_map, epoch, OUT_DIR):
        if current_valid_map > self.best_valid_map:
            self.best_valid_map = current_valid_map
            print(f"\nBEST VALIDATION mAP: {self.best_valid_map}")
            print(f"SAVING BEST MODEL FOR EPOCH: {epoch+1}\n")
            torch.save(
                {
                    "epoch": epoch + 1,
                    "model_state_dict": model.state_dict(),
                },
                f"{OUT_DIR}/best_model.pth",
            )
The Averager class keeps a running total of a value, usually the training loss. Each time we get a loss, we call send. It then calculates an average by dividing the total by how many times we called it.
The SaveBestModel class stores the best validation mAP found so far. We measure the mAP after each epoch, and if the new mAP is higher than the best recorded one, we save the model weights. This way, we always keep the best model, even if later epochs do worse.
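A quick illustration of how Averager is used during an epoch (the loss values below are made up):
loss_hist = Averager()
for batch_loss in [0.9, 0.7, 0.5]:   # losses from three hypothetical batches
    loss_hist.send(batch_loss)
print(loss_hist.value)               # 0.7, the running mean of everything sent so far
loss_hist.reset()                    # call this at the start of the next epoch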
Collate Function and Transforms
def collate_fn(batch):
    """
    To handle the data loading as different images may have different
    numbers of objects, and to handle varying-size tensors as well.
    """
    return tuple(zip(*batch))


def get_train_transform():
    # We keep "pascal_voc" because the bounding box format is [x_min, y_min, x_max, y_max].
    return A.Compose(
        [
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=45),
            A.Blur(blur_limit=3, p=0.2),
            A.MotionBlur(blur_limit=3, p=0.1),
            A.MedianBlur(blur_limit=3, p=0.1),
            A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.3),
            A.RandomScale(scale_limit=0.2, p=0.3),
            ToTensorV2(p=1.0),
        ],
        bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
    )


def get_valid_transform():
    return A.Compose(
        [
            ToTensorV2(p=1.0),
        ],
        bbox_params={"format": "pascal_voc", "label_fields": ["labels"]},
    )
collate_fn is needed in object detection tasks because each image can have a different number of bounding boxes. This function helps the data loader combine them into a batch without errors.
The transform functions use Albumentations to apply random flips, rotations, and scaling for training. Validation uses less augmentation because we want a stable measure of model performance.
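If you want to sanity-check the augmentation pipeline, a quick call looks like this; the dummy image, box, and label below are purely illustrative:
import numpy as np
from custom_utils import get_train_transform

image = np.random.rand(640, 640, 3).astype(np.float32)   # dummy float image in [0, 1]
boxes = [[50, 60, 200, 220]]                              # pascal_voc: [x_min, y_min, x_max, y_max]
labels = [1]                                              # 1 = buffalo in our CLASSES list

out = get_train_transform()(image=image, bboxes=boxes, labels=labels)
print(out["image"].shape)   # channel-first tensor, e.g. torch.Size([3, 640, 640])
print(out["bboxes"])        # the box follows the image through every augmentation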
Visualizations, Saving, and Plotting
def show_tranformed_image(train_loader):
    # Displays transformed images with bounding boxes.
    pass


def save_model(epoch, model, optimizer):
    torch.save(
        {
            "epoch": epoch + 1,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
        },
        "outputs/last_model.pth",
    )


def save_loss_plot(...):
    # Creates and saves a plot of training loss.


def save_mAP(...):
    # Creates and saves a plot of mAP per epoch.
show_tranformed_image helps us see how our training images and bounding boxes look after augmentation, which can reveal mistakes or confirm everything is correct.
The save_model function writes the model weights and optimizer state to disk so we can resume training or do inference later.
We also have plotting functions to chart the loss and mAP progress, which helps us analyze whether the model is still improving.
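The two plotting helpers are only stubbed above. A possible implementation looks like the following; this is a sketch, and the shipped version may style the plots differently:
import matplotlib.pyplot as plt

def save_loss_plot(out_dir, train_loss_list):
    # One point per epoch: the average training loss.
    plt.figure(figsize=(10, 7))
    plt.plot(train_loss_list, color="tab:blue", label="train loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()
    plt.savefig(f"{out_dir}/train_loss.png")
    plt.close()
    print("SAVING PLOTS COMPLETE...")

def save_mAP(out_dir, map_05, map_all):
    # Two curves per epoch: mAP@0.5 and the stricter mAP@0.5:0.95.
    plt.figure(figsize=(10, 7))
    plt.plot(map_05, color="tab:orange", label="mAP@0.5")
    plt.plot(map_all, color="tab:red", label="mAP@0.5:0.95")
    plt.xlabel("Epochs")
    plt.ylabel("mAP")
    plt.legend()
    plt.savefig(f"{out_dir}/map.png")
    plt.close()
    print("SAVING mAP PLOTS COMPLETE...")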
datasets.py
Dataset Pre-processing
import os
from torch.utils.data import Dataset, DataLoader
from config import RESIZE_TO, CLASSES, BATCH_SIZE
from custom_utils import collate_fn, get_train_transform, get_valid_transform


class CustomDataset(Dataset):
    def __init__(self, dir_path, width, height, classes, transforms=None):
        self.dir_path = dir_path
        self.image_dir = os.path.join(self.dir_path, "images")
        self.label_dir = os.path.join(self.dir_path, "labels")
        self.width = width
        self.height = height
        self.classes = classes
        self.transforms = transforms
        # ...

    def __getitem__(self, idx):
        # 1) read image
        # 2) resize
        # 3) load bounding boxes
        # 4) convert to tensor
        # 5) apply transforms
        # ...
        return image_resized, target


def create_train_dataset(DIR):
    return CustomDataset(
        dir_path=DIR,
        width=RESIZE_TO,
        height=RESIZE_TO,
        classes=CLASSES,
        transforms=get_train_transform(),
    )


def create_valid_dataset(DIR):
    return CustomDataset(
        dir_path=DIR,
        width=RESIZE_TO,
        height=RESIZE_TO,
        classes=CLASSES,
        transforms=get_valid_transform(),
    )
Here, the CustomDataset class does the main work: it reads the images and labels from the provided folders, resizes the images to the dimension we set in config.py, and converts each bounding box to the correct tensor form. The label files store class IDs and normalized coordinates for each bounding box. We multiply those coordinates by width and height so the bounding boxes match the resized image. Then, we apply the transforms we defined in custom_utils.py. We skipped part of the code for length; it is available in the Download Code section, but a simplified sketch follows below.
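To give an idea of what that skipped code does, here is a simplified sketch of __getitem__. It is illustrative only: self.all_images is a hypothetical sorted list of file names, and the downloadable script additionally handles edge cases such as empty label files.
# Inside CustomDataset (simplified sketch).
import os
import cv2
import numpy as np
import torch

def __getitem__(self, idx):
    image_name = self.all_images[idx]
    image = cv2.imread(os.path.join(self.image_dir, image_name))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
    image_resized = cv2.resize(image, (self.width, self.height))

    boxes, labels = [], []
    label_path = os.path.join(self.label_dir, os.path.splitext(image_name)[0] + ".txt")
    with open(label_path) as f:
        for line in f:
            cls_id, x_min, y_min, x_max, y_max = map(float, line.split())
            labels.append(int(cls_id) + 1)   # shift by 1 because index 0 is __background__
            # Scale the normalized corner coordinates to the resized image.
            boxes.append([x_min * self.width, y_min * self.height,
                          x_max * self.width, y_max * self.height])

    if self.transforms:
        sample = self.transforms(image=image_resized, bboxes=boxes, labels=labels)
        image_resized = sample["image"]
        boxes, labels = sample["bboxes"], sample["labels"]

    target = {
        "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
        "labels": torch.as_tensor(labels, dtype=torch.int64),
    }
    return image_resized, target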
The create_train_dataset and create_valid_dataset functions simply construct these dataset objects with the right transforms.
DataLoader Creation
def create_train_loader(train_dataset, num_workers=0):
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True,
    )
    return train_loader


def create_valid_loader(valid_dataset, num_workers=0):
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True,
    )
    return valid_loader
We then wrap each dataset in a DataLoader, which batches the images and calls our collate_fn from earlier. We shuffle the training data so the model sees a varied order of samples each epoch. Using drop_last=True is a common practice in detection tasks so that every batch has the same size.
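As a quick check, you can pull one batch out of the loader and inspect it (illustrative snippet):
from config import TRAIN_DIR, NUM_WORKERS
from datasets import create_train_dataset, create_train_loader

train_loader = create_train_loader(create_train_dataset(TRAIN_DIR), NUM_WORKERS)
images, targets = next(iter(train_loader))
print(len(images))                 # BATCH_SIZE images, each a CHW tensor
print(targets[0]["boxes"].shape)   # (num_objects, 4) boxes for the first image
print(targets[0]["labels"])        # class indices for those boxes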
model.py
Before jumping into the main code, let’s visualize the model:
RetinaNet(
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(...)
      (layer2): Sequential(...)
      (layer3): Sequential(...)
      (layer4): Sequential(...)
    )
    (fpn): FeaturePyramidNetwork(
      (inner_blocks): ModuleList(
        (0): Conv2dNormActivation(...)
      )
      (layer_blocks): ModuleList(...)
      (extra_blocks): LastLevelP6P7(...)
    )
  )
At a high level, RetinaNet has a backbone (a ResNet) and a Feature Pyramid Network (FPN) that extracts multi-scale features. The FPN merges features from different stages of the backbone (early, mid, and deeper layers), ensuring that we capture both fine, high-resolution details and more abstract, low-resolution features.
Once the FPN has generated these multi-scale feature maps, two sub-networks (heads) operate on each scale of the pyramid: one for classifying anchors into objects, the Classification Head, and one for regressing box coordinates, the Regression Head.
  (anchor_generator): AnchorGenerator()
  (head): RetinaNetHead(
    (classification_head): RetinaNetClassificationHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(...)
        (1): Conv2dNormActivation(...)
        (2): Conv2dNormActivation(...)
        (3): Conv2dNormActivation(...)
      )
      (cls_logits): Conv2d(256, 45, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (regression_head): RetinaNetRegressionHead(
      (conv): Sequential(
        (0): Conv2dNormActivation(...)
        (1): Conv2dNormActivation(...)
        (2): Conv2dNormActivation(...)
        (3): Conv2dNormActivation(...)
      )
      (bbox_reg): Conv2d(256, 36, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
)
Classification Head
This sub-network assigns a class probability to each anchor. Specifically, it outputs a score for each possible category (plus background) for every anchor on each scale. In the original COCO-trained model, the classification head outputs scores for 91 category slots (COCO’s 80 labeled object classes plus background and unused indices), which is why create_model below defaults to num_classes=91. To adapt it to a new dataset, we replace this head with one that produces as many outputs as we have categories, which can be fewer or more.
Regression Head
This parallel sub-network refines each anchor to match the ground-truth bounding boxes. It predicts offsets (delta x, delta y, delta width, delta height) for each anchor. These offsets “shift” and “scale” the anchor into the final bounding box for the object.
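For intuition, this is the standard anchor-to-box decoding used by torchvision-style detectors, shown as a simplified sketch without the scaling weights the library applies; it is not the exact library code:
import math

def decode_box(anchor, deltas):
    # anchor: [x_min, y_min, x_max, y_max]; deltas: [dx, dy, dw, dh] predicted by the head.
    xa_min, ya_min, xa_max, ya_max = anchor
    wa, ha = xa_max - xa_min, ya_max - ya_min
    xa_c, ya_c = xa_min + 0.5 * wa, ya_min + 0.5 * ha

    dx, dy, dw, dh = deltas
    x_c = xa_c + dx * wa          # shift the anchor center
    y_c = ya_c + dy * ha
    w = wa * math.exp(dw)         # rescale the anchor size
    h = ha * math.exp(dh)
    return [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h]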
The GeneralizedRCNNTransform in RetinaNet applies transformations before passing the input image through the model. It ensures that the input is preprocessed consistently before feature extraction and prediction.
Before we go back to the main code: if you want to learn about object detection and deep learning from scratch, do join our free bootcamps to get started!
Now, back to our code:
import torchvision
import torch
from functools import partial
from torchvision.models.detection import RetinaNet_ResNet50_FPN_V2_Weights
from torchvision.models.detection.retinanet import RetinaNetClassificationHead
from config import NUM_CLASSES


def create_model(num_classes=91):
    """
    Creates a RetinaNet-ResNet50-FPN v2 model pre-trained on COCO.
    Replaces the classification head for the required number of classes.
    """
    model = torchvision.models.detection.retinanet_resnet50_fpn_v2(
        weights=RetinaNet_ResNet50_FPN_V2_Weights.COCO_V1
    )
    num_anchors = model.head.classification_head.num_anchors

    # Replace the classification head.
    model.head.classification_head = RetinaNetClassificationHead(
        in_channels=256,
        num_anchors=num_anchors,
        num_classes=num_classes,
        norm_layer=partial(torch.nn.GroupNorm, 32),
    )
    return model
We load a RetinaNet model pre-trained on the COCO dataset and replace its classification head with a new one that has the correct number of classes for our dataset. The new RetinaNetClassificationHead will learn to produce num_classes outputs. We also keep the num_anchors setup from the original model, which is often enough for many object detection tasks. By loading the COCO weights, the backbone and much of the detection layers start from a strong baseline, and we only need to fine-tune them for our new wildlife classes.
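A quick sanity check after building the model (assuming config.py and model.py are on the Python path) confirms that the torchvision detection API returns a dictionary of boxes, scores, and labels per image:
import torch
from config import NUM_CLASSES
from model import create_model

model = create_model(num_classes=NUM_CLASSES)   # 5 outputs: background + 4 animals
model.eval()
with torch.no_grad():
    outputs = model([torch.rand(3, 640, 640)])  # a list of CHW tensors, one per image
print(outputs[0].keys())                        # dict_keys(['boxes', 'scores', 'labels'])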
train.py
Now, we use all the scripts we have created before in our main training script and start the training.
Importing and Setting Up
from config import (
    DEVICE,
    NUM_CLASSES,
    NUM_EPOCHS,
    OUT_DIR,
    VISUALIZE_TRANSFORMED_IMAGES,
    NUM_WORKERS,
    RESIZE_TO,
    VALID_DIR,
    TRAIN_DIR,
)
from model import create_model
from custom_utils import Averager, SaveBestModel, save_model, save_loss_plot, save_mAP
from tqdm.auto import tqdm
from datasets import create_train_dataset, create_valid_dataset, create_train_loader, create_valid_loader
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
import torch
import matplotlib.pyplot as plt
import time
import os

plt.style.use("ggplot")

seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
We import all the items we need and read from our config.py to keep things consistent. We also bring in the custom utilities, like Averager for tracking the loss and SaveBestModel for storing our best model weights. We use MeanAveragePrecision from torchmetrics to measure how precise our model is at different IoU thresholds.
The Training Function
# Function for running training iterations.
def train(train_data_loader, model):
    print("Training")
    model.train()

    # Initialize tqdm progress bar.
    prog_bar = tqdm(train_data_loader, total=len(train_data_loader))

    for i, data in enumerate(prog_bar):
        optimizer.zero_grad()
        images, targets = data

        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()

        train_loss_hist.send(loss_value)

        losses.backward()
        optimizer.step()

        # Update the loss value beside the progress bar for each iteration.
        prog_bar.set_description(desc=f"Loss: {loss_value:.4f}")
    return loss_value
When we call train, we set the model into training mode and loop over each batch in train_data_loader. We move the images and their bounding box data to the GPU or CPU device, do a forward pass through the model to get a dictionary of losses, then sum them up and backpropagate. optimizer.step() updates the model weights based on the gradients.
The Validation Function
# Function for running validation iterations.
def validate(valid_data_loader, model):
    print("Validating")
    model.eval()

    # Initialize tqdm progress bar.
    prog_bar = tqdm(valid_data_loader, total=len(valid_data_loader))
    target = []
    preds = []
    for i, data in enumerate(prog_bar):
        images, targets = data

        images = list(image.to(DEVICE) for image in images)
        targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]

        with torch.no_grad():
            outputs = model(images, targets)

        # For mAP calculation using Torchmetrics.
        #####################################
        for i in range(len(images)):
            true_dict = dict()
            preds_dict = dict()
            true_dict["boxes"] = targets[i]["boxes"].detach().cpu()
            true_dict["labels"] = targets[i]["labels"].detach().cpu()
            preds_dict["boxes"] = outputs[i]["boxes"].detach().cpu()
            preds_dict["scores"] = outputs[i]["scores"].detach().cpu()
            preds_dict["labels"] = outputs[i]["labels"].detach().cpu()
            preds.append(preds_dict)
            target.append(true_dict)
        #####################################

    metric.reset()
    metric.update(preds, target)
    metric_summary = metric.compute()
    return metric_summary
The validate function puts the model in evaluation mode, and we skip gradient calculations by using torch.no_grad(). For each batch of images, we get the outputs, then store them along with the ground-truth targets in lists so that MeanAveragePrecision can compute the average precision across all images and classes.
Main Training Loop
if __name__ == "__main__":
    os.makedirs("outputs", exist_ok=True)
    train_dataset = create_train_dataset(TRAIN_DIR)
    valid_dataset = create_valid_dataset(VALID_DIR)
    train_loader = create_train_loader(train_dataset, NUM_WORKERS)
    valid_loader = create_valid_loader(valid_dataset, NUM_WORKERS)
    print(f"Number of training samples: {len(train_dataset)}")
    print(f"Number of validation samples: {len(valid_dataset)}\n")

    # Initialize the model and move to the computation device.
    model = create_model(num_classes=NUM_CLASSES)
    model = model.to(DEVICE)
    print(model)

    # Total parameters and trainable parameters.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")
We begin by creating the training and validation datasets, then wrap each in a DataLoader
that batches and shuffles the data. Next, we build a RetinaNet model with the required number of classes, place it on the GPU (or CPU if necessary), and print the total parameters.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=0.01, momentum=0.9, nesterov=True, weight_decay=0.0005)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="max",       # we want to maximize mAP
        factor=0.1,       # reduce LR by this factor
        patience=8,       # wait 8 epochs with no improvement
        threshold=0.005,  # how much improvement is considered significant
        cooldown=1,
    )

    # To monitor training loss.
    train_loss_hist = Averager()
    # To store training loss and mAP values.
    train_loss_list = []
    map_50_list = []
    map_list = []

    # Name to save the trained model with.
    MODEL_NAME = "model"

    # Whether to show transformed images from the data loader or not.
    if VISUALIZE_TRANSFORMED_IMAGES:
        from custom_utils import show_tranformed_image
        show_tranformed_image(train_loader)

    # To save the best model.
    save_best_model = SaveBestModel()
    metric = MeanAveragePrecision()
    metric.warn_on_many_detections = False
We choose an SGD optimizer with momentum and weight decay to help prevent overfitting. In this code, we use ReduceLROnPlateau to monitor our validation mAP at IoU=0.5:0.95. If the mAP does not improve by at least threshold=0.005 for patience=8 epochs, the learning rate is reduced by a factor of 0.1, allowing finer adjustments in later stages of training.
    # Training loop.
    for epoch in range(NUM_EPOCHS):
        print(f"\nEPOCH {epoch+1} of {NUM_EPOCHS}")

        # Reset the training loss history for the current epoch.
        train_loss_hist.reset()

        # Start timer and carry out training and validation.
        start = time.time()
        train_loss = train(train_loader, model)
        metric_summary = validate(valid_loader, model)
        current_map_05_95 = float(metric_summary["map"])
        current_map_05 = float(metric_summary["map_50"])
        print(f"Epoch #{epoch+1} train loss: {train_loss_hist.value:.3f}")
        print(f"Epoch #{epoch+1} mAP: {metric_summary['map']:.3f}")
        end = time.time()
        print(f"Took {((end - start) / 60):.3f} minutes for epoch {epoch+1}")

        train_loss_list.append(train_loss)
        map_50_list.append(metric_summary["map_50"])
        map_list.append(metric_summary["map"])

        # Save the best model till now.
        save_best_model(model, float(metric_summary["map"]), epoch, "outputs")
        # Save the current epoch model.
        save_model(epoch, model, optimizer)

        # Save loss plot.
        save_loss_plot(OUT_DIR, train_loss_list)
        # Save mAP plot.
        save_mAP(OUT_DIR, map_50_list, map_list)

        scheduler.step(current_map_05_95)
        print("Current LR:", scheduler.get_last_lr())
During each epoch, we first reset our Averager to track the new epoch’s training loss. We call train(...), which iterates over batches from the training loader; for each batch, we compute the losses, backpropagate, and update the weights.
Then we call validate(...) on the validation loader to measure performance without computing gradients. We gather mAP metrics (both at IoU=0.5 and IoU=0.5:0.95) and log them. If the current mAP is higher than in any previous epoch, we save these improved model weights. We also save a plot of the training loss and a plot of the mAP curves.
Finally, we call scheduler.step(current_map_05_95), passing our main metric so the learning rate scheduler can decide whether it needs to lower the LR.
This entire setup keeps the code clean. The dataset script handles data reading and transformation. The model script sets up the detection architecture. The training file controls the epoch and batch logic, hyperparameters, and scheduling. Meanwhile, the utility scripts handle saving models, plotting results, and providing helper functions like the averaging mechanism.
Now, if you look at the training logs:
EPOCH 1 of 60
Training
Loss: 0.4067: 100%|█████████████████████████████████████████| 159/159 [01:44<00:00, 1.52it/s]
Validating
100%|█████████████████████████████████████████████████████████| 28/28 [00:07<00:00, 3.51it/s]
Epoch #1 train loss: 0.594
Epoch #1 mAP: 0.451
Took 1.884 minutes for epoch 1
BEST VALIDATION mAP: 0.45112985372543335
SAVING BEST MODEL FOR EPOCH: 1
SAVING PLOTS COMPLETE...
SAVING mAP PLOTS COMPLETE...
Current LR: [0.01]
.
.
.
EPOCH 42 of 60
Training
Loss: 0.0988: 100%|█████████████████████████████████████████| 159/159 [02:00<00:00, 1.32it/s]
Validating
100%|█████████████████████████████████████████████████████████| 28/28 [00:08<00:00, 3.12it/s]
Epoch #42 train loss: 0.134
Epoch #42 mAP: 0.790
Took 2.169 minutes for epoch 42
BEST VALIDATION mAP: 0.78984534740448
SAVING BEST MODEL FOR EPOCH: 42
SAVING PLOTS COMPLETE...
SAVING mAP PLOTS COMPLETE...
Current LR: [0.0001]
.
.
.
So, we ran the training for 60 epochs and reached a best validation mAP of about 79%. If you look at the loss and mAP plots:
During training, we tuned all the optimizer and scheduler parameters; it took us approximately 40–50 training runs to reach this level of accuracy. Now, let’s perform inference using our fine-tuned RetinaNet model.
Inference – Fine-Tuned RetinaNet vs Fine-Tuned YOLO11
To make the inference more interesting, we compare the results for both YOLO11 and RetinaNet side by side. We used YOLO11L (25.3M parameters), since our fine-tuned RetinaNet has around 36M parameters.
The YOLO11 training log looks like this:
optimizer: 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically...
optimizer: AdamW(lr=0.00125, momentum=0.9) with parameter groups 167 weight(decay=0.0), 174 weight(decay=0.0005), 173 bias(decay=0.0)
Image sizes 640 train, 640 val
Using 4 dataloader workers
Logging results to runs/detect/train
Starting training for 60 epochs...
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
1/60 5.68G 0.9655 1.644 1.391 12 640: 1
Class Images Instances Box(P R mAP50 m
all 225 379 0.452 0.341 0.354 0.19
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
.
.
.
Epoch GPU_mem box_loss cls_loss dfl_loss Instances Size
60/60 5.88G 0.4349 0.3049 0.9714 9 640: 1
Class Images Instances Box(P R mAP50 m
all 225 379 0.924 0.907 0.948 0.806
60 epochs completed in 0.581 hours.
Optimizer stripped from runs/detect/train/weights/last.pt, 51.2MB
Optimizer stripped from runs/detect/train/weights/best.pt, 51.2MB
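For reference, the Ultralytics run behind this log boils down to a few lines like the following. This is a sketch; our exact arguments may have differed slightly, and african-wildlife.yaml refers to the African Wildlife dataset configuration shipped with Ultralytics:
from ultralytics import YOLO

model = YOLO("yolo11l.pt")            # YOLO11-Large, ~25.3M parameters
model.train(
    data="african-wildlife.yaml",     # Ultralytics' African Wildlife dataset config
    epochs=60,
    imgsz=640,
)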
To make the comparison unbiased, we trained the YOLO model for the same 60 epochs using the best configuration suited for it. Both models expect an image size of 640×640 for training, so we kept it the same.
As a result, YOLO11L achieved an mAP of about 81%.
PS: We used an NVIDIA GeForce RTX 3070 Ti Laptop GPU to run the training and inference.
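If you want to reproduce the RetinaNet side of the comparison, a minimal single-image inference script looks roughly like this. The image path and the 0.5 score threshold are illustrative choices; the full inference.py ships with the downloadable code:
import cv2
import torch
from model import create_model
from config import DEVICE, NUM_CLASSES, CLASSES

model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model.pth", map_location=DEVICE)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(DEVICE).eval()

image = cv2.cvtColor(cv2.imread("data/test/images/sample.jpg"), cv2.COLOR_BGR2RGB)
tensor = torch.from_numpy(image).permute(2, 0, 1).float().to(DEVICE) / 255.0

with torch.no_grad():
    output = model([tensor])[0]        # boxes come back in original image coordinates

keep = output["scores"] > 0.5
for box, label in zip(output["boxes"][keep], output["labels"][keep]):
    x1, y1, x2, y2 = box.int().tolist()
    cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 2)
    cv2.putText(image, CLASSES[int(label)], (x1, y1 - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
cv2.imwrite("retinanet_result.jpg", cv2.cvtColor(image, cv2.COLOR_BGR2RGB))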
Now, let’s visualize the results:
Observations
Both models perform well, but RetinaNet’s predictions were more accurate than YOLO11’s most of the time in our testing. As you can see, YOLO11 often gives a false “elephant” prediction in the case of buffalo and rhino.
PS:
- We used a fixed input size of 640, while YOLO11 resizes dynamically during inference, which might not be an issue here.
- YOLO11L might need a longer training run (for example, 100 epochs). We had to stop training at 60 epochs to keep the comparison fair.
We made a Gradio app for you to play with the model. Go and try different inputs, and send us some cool results!
Also, this code is included in the Download Code folder you downloaded earlier.
ONNX CPU Inference
Now, we have exported both models to ONNX format to run CPU inference. Let’s see which one has a better FPS!
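The export itself follows the standard torch.onnx.export route for torchvision detection models. The snippet below is a rough sketch with an illustrative file name and opset (detection models can need extra export arguments); the YOLO11 side is a one-liner with Ultralytics' model.export(format="onnx"):
import torch
from model import create_model
from config import NUM_CLASSES

model = create_model(num_classes=NUM_CLASSES)
checkpoint = torch.load("outputs/best_model.pth", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

# Torchvision detection models take a list of CHW tensors as input.
dummy_input = [torch.rand(3, 640, 640)]
torch.onnx.export(model, dummy_input, "retinanet_wildlife.onnx", opset_version=11)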
We ran inference on the same video for both models, using the same IoU threshold. The logs for both models:
# finetuned retinanet.onnx
Inference completed. Processed 235 frames in 157.97s at 1.49 fps.
# finetuned yolo11l.onnx
Speed: 13.8ms preprocess, 400.3ms inference, 0.7ms postprocess, 2.50 fps.
Both models are a bit slow, but the good part is that we can run the inference on a CPU. We have provided the ONNX models and inference code in the Download Code section.
Quick Recap
Real-World Conservation Needs
Increasing threats like poaching and habitat loss demand efficient wildlife detection methods. Traditional surveys are slow and expensive, making modern computer vision solutions essential.
Why RetinaNet?
With Focal Loss and a Feature Pyramid Network, RetinaNet manages class imbalance and multiple object scales well. It shines in cases where animals are small, partially hidden, or uncommon.
Training Strategy
We used a COCO-pretrained RetinaNet, applied targeted data augmentations, and tuned our optimizer and scheduler to balance learning speed with accuracy. This setup led us to a solid detection performance on our wildlife dataset.
YOLO11 vs. RetinaNet
In direct comparisons, finetuned RetinaNet often delivered higher precision than YOLO11—especially for more challenging classes like rhinos and buffaloes—showing that focusing on accuracy can be just as important as speed.
Conclusion
By leveraging RetinaNet’s advanced features and a well-planned fine-tuning process, we achieved strong wildlife detection even under challenging conditions. This approach helps conservationists track species faster and more accurately, saving time and resources while protecting Africa’s most vulnerable animals. Download and try the code, make changes, and play with all the parameters. Do let us know if you get a better accuracy!
See you in the next blog! Bye 😀
References
Train PyTorch RetinaNet on Custom Dataset