Fine Tuning Whisper on Custom Dataset

In this article, we fine tune the Whisper ASR model on a custom dataset to recognize Air Traffic Control audio.

Whisper is a leading open-source model used for converting speech to text. Developed by OpenAI, Whisper has been trained on a diverse array of languages and speech conditions using extensive data. Despite its advanced capabilities, Whisper sometimes struggles with transcribing new or unusual types of audio. To address this, it’s essential to fine-tune the model on specific datasets. In this article, we’ll guide you through the process of fine tuning Whisper using a custom dataset tailored for Automatic Speech Recognition in Air Traffic Control. This not only enhances transcription accuracy but also boosts safety in real-world operations.

By the end of this article, you’ll know how to adjust Whisper to work effectively with your own data. Additionally, we’ll explore how different versions of Whisper – Tiny, Base, and Small – perform in terms of processing speed and accuracy with practical examples from our tests.

Figure 1. Fine Tuning Whisper on Custom Dataset

We aim to make the smaller versions of Whisper performant on a specific audio dataset where the pretrained models perform poorly. In this article, we will explore the following topics in detail while fine tuning the Whisper Small model:

  • Understanding the dataset – We will use the Air Traffic Control audio dataset for fine tuning Whisper.
  • Going through the fine tuning code – We will start with the dataset preparation and eventually fine-tune the Whisper model on the custom dataset.
  • Comparing error rates and inference time – We will use Word Error Rate (WER) as the benchmark metric and also compare the inference time for each of the fine tuned Whisper models.


The Air Traffic Control Dataset

The atco2-asr-atcosim dataset will be used for fine tuning Whisper. It is a combination of two different Air Traffic Control (ATC) operator speech datasets.

Following is a sample audio from the validation set.

As you may realize, these operator recordings are extremely difficult for the average listener to comprehend. Although most of the conversations happen between the operator and the pilots, there are several reasons why an ASR (Automatic Speech Recognition) system for transcribing these audios can be beneficial:

  • Increased safety and performance.
  • Verifying whether instructions were clearly understood and followed.
  • Training the future workforce.

Using pretrained Whisper models on these audio files for ASR does not produce good results. This calls for the fine tuning of the models for better results.

Next, let's take a brief look at the models that we will use for fine-tuning.

Whisper Models for Fine Tuning on the Air Traffic Control Dataset

We want the results to be not only accurate but also fast. For this reason, we will not use the Whisper Medium and Large models; rather, we will focus on three smaller variants: Whisper Tiny, Base, and Small.

The following table covers the details of the above three models.

Size    Parameters   English-only model   Multilingual model   Required VRAM   Relative speed
tiny    39 M         tiny.en              tiny                 ~1 GB           ~32x
base    74 M         base.en              base                 ~1 GB           ~16x
small   244 M        small.en             small                ~2 GB           ~6x

As we can see, the required VRAM for each of the models is within acceptable limits, even for real-time deployment. In fact, the Tiny and Base models are quite fast on the CPU as well.


Fine Tuning Whisper on Custom Dataset for Air Traffic Control Audio

From this section onward, we will focus on the training code for fine tuning Whisper on the custom dataset. Although we have three different models for fine-tuning, we will cover the code for the Whisper Small model only. All the Jupyter Notebooks, trained weights, and inference scripts are downloadable via the Download section.

Before we dive into the training code, the following are the libraries that we need:

  • datasets 
  • transformers 
  • accelerate 
  • evaluate 
  • jiwer 
  • tensorboard 
  • gradio

We are building on top of the PyTorch framework. All the necessary dependencies will be installed via the first code cell of the Jupyter Notebook.

The code here is present in the fine_tune_whisper_small_atco2.ipynb notebook.

Let’s start with the installation of all the necessary components.

!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio

The Import Statements

Next, import the libraries and modules that we will need throughout the notebook.

from datasets import load_dataset, DatasetDict
from transformers import (
    WhisperTokenizer,
    WhisperProcessor,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
from datasets import Audio
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch
import evaluate

Let’s cover some of the important ones in detail:

  • datasets: To load the Air Traffic Control dataset.
  • WhisperTokenizer: To tokenize the ground truth transcriptions of the audio files.
  • WhisperFeatureExtractor: This is used to extract the features from the audio arrays.
  • WhisperProcessor: This class combines both the Whisper feature extractor and the Whisper tokenizer for easier access.
  • WhisperForConditionalGeneration: To load the pretrained Whisper model.
  • Seq2SeqTrainingArguments and Seq2SeqTrainer: To define the training arguments and the Trainer API respectively.

Defining Some Constants and Training Parameters

Now, let’s define some of the training parameters we will need later in the notebook.

model_id = 'openai/whisper-small'
out_dir = 'whisper_small_atco2'
epochs = 10
batch_size = 32

The model_id is the pretrained repository ID from Hugging Face. out_dir is the output directory where the training artifacts will be stored. We will train the model for 10 epochs with a batch size of 32.

Loading the Dataset, Tokenizer, and Feature Extractor

Dataset preparation is one of the most important aspects of starting the fine tuning process.

The first step is to load the training and validation sets from Hugging Face datasets.

atc_dataset_train = load_dataset('jlvdoorn/atco2-asr-atcosim', split='train')
atc_dataset_valid = load_dataset('jlvdoorn/atco2-asr-atcosim', split='validation')

Printing the above two dataset splits gives the following output.

Figure 2. Air Traffic Control dataset samples and features.

There are 8092 training samples and 2026 validation samples. Furthermore, there are three features: audio, text, and info.

Let’s check what one sample from the training set contains.
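If you want to inspect a sample programmatically, a minimal sketch like the following prints the same fields shown in the figure below.

sample = atc_dataset_train[0]
# The audio feature is a dictionary with the file path, the 1-D waveform array, and the sampling rate.
print(sample['audio']['sampling_rate'])  # 16000 for this dataset
print(sample['audio']['array'][:10])     # first few values of the waveform
print(sample['text'])                    # ground truth transcription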

Figure 3. An audio sample and its corresponding transcription from the Air Traffic Control dataset.

The audio feature contains the audio file path and the 1-D array of the audio file in numeric format along with the sampling rate. The text feature is the ground truth transcription. The info feature contains any additional information. However, we won’t need the info feature during the fine tuning process.

The sampling rate is one of the most critical components to consider while fine tuning Whisper. All the Whisper models have been pretrained on audio with a 16 kHz sampling rate. Here, the dataset audio files are already at 16 kHz. If that were not the case, we would need to resample them to the required rate.

Next, load the feature extractor, tokenizer, and the Whisper processor.

feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)

tokenizer = WhisperTokenizer.from_pretrained(model_id, language='English', task='transcribe')

processor = WhisperProcessor.from_pretrained(model_id, language='English', task='transcribe')

When loading the tokenizer and processor, we define the target transcript language and task. Since all the transcriptions are in English, we set the language to English and the task to transcribe so that the model transcribes the audio rather than translating it.
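As an optional sanity check (a small sketch using the first training sample), we can encode and decode one transcription to confirm that the tokenizer round-trips ATC text without loss.

# Encode a sample transcription to token IDs and decode it back.
sample_text = atc_dataset_train[0]['text']
ids = tokenizer(sample_text).input_ids
decoded = tokenizer.decode(ids, skip_special_tokens=True)
print(f'Original: {sample_text}')
print(f'Decoded : {decoded}')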

Preprocessing the Dataset

Preprocessing the dataset into the proper format is necessary before we can feed it to the training pipeline.

The foremost step here is handling the sampling rate of the audio files. As discussed earlier, Whisper models only accept audio with a 16,000 Hz sampling rate. Although our dataset already contains samples at that rate, it is still a good idea to have a processing step that ensures it.

atc_dataset_train = atc_dataset_train.cast_column('audio', Audio(sampling_rate=16000))
atc_dataset_valid = atc_dataset_valid.cast_column('audio', Audio(sampling_rate=16000))

We use the cast_column method to modify the audio feature. To modify the sampling rate, use the Audio class that was imported earlier.
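A quick optional check to confirm that the cast took effect is to inspect the audio feature of the dataset.

# The audio column should now report a 16 kHz sampling rate.
print(atc_dataset_train.features['audio'].sampling_rate)  # 16000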

The next step involves extracting the audio features and tokenizing the transcriptions. We can write a helper function and then map the datasets for this.

def prepare_dataset(batch):
    audio = batch['audio']

    batch['input_features'] = feature_extractor(audio['array'], sampling_rate=audio['sampling_rate']).input_features[0]

    batch['labels'] = tokenizer(batch['text']).input_ids

    return batch

atc_dataset_train = atc_dataset_train.map(
    prepare_dataset,
    num_proc=4
)

atc_dataset_valid = atc_dataset_valid.map(
    prepare_dataset,
    num_proc=4
)

The final step in the dataset processing is the data collator.

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{'input_features': feature['input_features']} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors='pt')

        label_features = [{'input_ids': feature['labels']} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors='pt')

        labels = labels_batch['input_ids'].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch['labels'] = labels

        return batch

The data collator performs the following steps (the toy example after this list illustrates the masking):

  • Converting the audio input features and the tokenized transcriptions into padded PyTorch tensors.
  • Replacing all padded label tokens with -100 so the loss function ignores them during loss calculation.
  • Stripping the decoder start token from the beginning of the labels if the tokenizer already prepended it, since it is added again during training.
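To make the masking step concrete, here is a toy illustration with made-up token IDs (not real Whisper tokens) showing how masked_fill replaces padded positions with -100.

# Positions where the attention mask is 0 (padding) are replaced with -100,
# which the loss function ignores by default.
toy_labels = torch.tensor([[50258, 123, 456, 50257, 50257]])
toy_mask = torch.tensor([[1, 1, 1, 0, 0]])
print(toy_labels.masked_fill(toy_mask.ne(1), -100))
# tensor([[50258,  123,  456, -100, -100]])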

Loading the Whisper Model

Loading the pretrained Whisper model is straightforward. 

model = WhisperForConditionalGeneration.from_pretrained(model_id)

model.generation_config.task = 'transcribe'

model.generation_config.forced_decoder_ids = None

Additionally, we need to create an instance of the data collator, which completes the configuration steps required before fine tuning.

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)
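If you want to sanity-check the collator, an optional sketch is to pass it a couple of prepared samples; the shapes noted in the comments are what the Whisper Small feature extractor typically produces (80 log-Mel bins over 3000 frames, i.e. 30 seconds of padded audio).

# Collate two prepared samples into a single padded batch.
test_batch = data_collator([atc_dataset_train[0], atc_dataset_train[1]])
print(test_batch['input_features'].shape)  # e.g. torch.Size([2, 80, 3000])
print(test_batch['labels'].shape)          # (2, longest label length), padded with -100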

Defining the Evaluation Metric

We will use WER (Word Error Rate) as the evaluation metric. This is one of the most common metrics for evaluating speech recognition models.

metric = evaluate.load('wer')

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {'wer': wer}

You can refer to this Hugging Face space to learn more about WER.
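For intuition, WER divides the number of word-level substitutions, deletions, and insertions by the number of words in the reference. Here is a tiny illustrative computation with made-up ATC-style strings.

# One deleted word ('to') against a 7-word reference: WER = 1/7 ≈ 0.143.
example_wer = metric.compute(
    predictions=['descend flight level one two zero'],
    references=['descend to flight level one two zero']
)
print(example_wer)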

Defining Training Arguments, Trainer API, and Starting Training

Let’s define the training arguments with the necessary parameters.

training_args = Seq2SeqTrainingArguments(
    output_dir=out_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=1,
    learning_rate=0.00001,
    warmup_steps=1000,
    bf16=True,
    fp16=False,
    num_train_epochs=epochs,
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    predict_with_generate=True,
    generation_max_length=225,
    report_to=['tensorboard'],
    load_best_model_at_end=True,
    metric_for_best_model='wer',
    greater_is_better=False,
    dataloader_num_workers=8,
    save_total_limit=2,
    lr_scheduler_type='constant',
    seed=42,
    data_seed=42
)

Above, bf16=True and fp16=False have been defined because the training experiment was run on an RTX GPU. In case you run the training on a T4 or P100 GPU, be sure to use fp16=True (and bf16=False) instead, as those GPUs do not support the BF16 data type.
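If you are unsure what your GPU supports, a small convenience check like the following (an optional sketch) can drive the two precision flags.

# BF16 is only available on Ampere and newer NVIDIA GPUs.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
use_fp16 = torch.cuda.is_available() and not use_bf16
# Then pass bf16=use_bf16 and fp16=use_fp16 to Seq2SeqTrainingArguments above.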

The next code block contains the initialization of the Trainer API.

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=atc_dataset_train,
    eval_dataset=atc_dataset_valid,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

Finally, we can call the train method of trainer to start the training process.

trainer.train()
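Since load_best_model_at_end=True restores the best checkpoint in memory, it is convenient to save it together with the processor into a single directory once training finishes. The best_model path below is simply a convention that matches the inference scripts used later in this article.

# Save the best model and the processor so inference scripts can load them from one directory.
trainer.save_model(f'{out_dir}/best_model')
processor.save_pretrained(f'{out_dir}/best_model')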

Note: The Whisper Small fine-tuning was run on a 48 GB RTX A6000 GPU to accommodate the memory requirements of a batch size of 32. The Whisper Tiny and Base fine-tuning runs were carried out on RTX 4090 GPUs.

Following are the logs from the above training run.

Figure 4. Training log after fine tuning Whisper on the Air Traffic Control dataset.
Figure 5. Validation WER graph after fine tuning Whisper.

We get the best WER of 3.15 on Epoch 7. The model saved from this epoch will be used for inference.

Comparing Fine Tuned Whisper Tiny, Base, and Small Models

From the different fine tuning experiments, we obtain the following comparison, which depicts:

  • The best WER from each fine-tuning run.
  • The ground truth and inference results of the first three samples from the validation dataset for each of the best Whisper models.

Figure 6. Comparison of Whisper Small, Base, and Tiny models on three validation samples after fine tuning.

Needless to say, the Whisper Small model gives the best overall results, though it is still erroneous in places. It is interesting to see that before fine tuning, the Whisper models produced garbage transcriptions, and the fine tuning process brought them extremely close to the ground truth. This shows that open-source AI models can be used to build real-life applications, provided we follow the right strategy.

Inference Time Comparison of the Fine Tuned Whisper Models

For real-world deployment, inference time also matters, as it in turn affects deployment cost. The following script reads three audio files from the inference_data directory and computes the average inference time of each model.

The following code is present in the compare_time.py script.

Note: The inference time comparison runs were performed on a laptop RTX 3070 Ti GPU.

"""
Script to compare time for fine-tuned Whisper models.
"""

import torch
import time
import os

from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

model_dirs = [
    'whisper_tiny_atco2_v2/best_model',
    'whisper_base_atco2/best_model',
    'whisper_small_atco2/best_model'
]

input_dir = 'inference_data'

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

for model_id in model_dirs:
    print(f"\nEvaluating model: {model_id}")

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        'automatic-speech-recognition',
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device
    )

    total_time = 0
    num_runs = 0

    for _ in range(10):
        for filename in os.listdir(input_dir):
            if filename.endswith('.wav'):
                start_time = time.time()
                result = pipe(os.path.join(input_dir, filename))
                end_time = time.time()
                total_time += (end_time - start_time)
                num_runs += 1

    average_time = total_time / num_runs
    print(f"\nAverage time taken for {model_id}: {average_time} seconds")

As the first run always takes more time (because of GPU warmup), we run each model 10 times over each of the three audio files and average the timings.
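If you prefer to exclude warmup entirely rather than amortize it, a small tweak (a sketch using the same variables as the script above, placed just before the timing loop) is to run one untimed pass per model.

# Untimed warmup pass on a single audio file; the result is discarded.
warmup_file = next(f for f in os.listdir(input_dir) if f.endswith('.wav'))
_ = pipe(os.path.join(input_dir, warmup_file))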

The following graph shows the average inference time for each fine tuned Whisper model.

Figure 7. Inference time graph showing comparison between fine tuned Whisper Small, Base, and Tiny models.

The fine tuned Whisper Tiny model has the lowest inference time. That may not seem like much of an advantage right now, since it also has the highest WER among the fine tuning runs. However, if we scale up the dataset, we can bring the WER down, which would make the Whisper Tiny model a great choice for real-time deployment.

Gradio UI for Whisper Fine Tuning

We have also prepared a simple Gradio demo for a smoother inference process using the fine tuned Whisper models. The code is in the gradio_ui.py script.

from transformers import pipeline
import gradio as gr

pipe = pipeline(
    model='whisper_small_atco2/best_model',
    tokenizer='whisper_small_atco2/best_model',
    task='automatic-speech-recognition',
    device='cuda'
) 

def transcribe(audio):
    text = pipe(audio)['text']
    return text

iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(sources=['microphone', 'upload'], type='filepath'),
    outputs='text'
)

iface.launch(share=True)

The following demo shows uploading an audio file, submitting it, and getting the result in the output text box.

Figure 8. Gradio UI demo for Whisper inference.

You can swap in the fine tuned Tiny or Base model as well, as shown below.
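For example, to serve the fine tuned Tiny model instead, point the pipeline at its directory (the path here is assumed from the training runs discussed earlier).

# Swap in the fine tuned Whisper Tiny model for faster inference.
pipe = pipeline(
    model='whisper_tiny_atco2_v2/best_model',
    task='automatic-speech-recognition',
    device='cuda'
)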

Key Takeaways

  • Whisper is an extremely powerful ASR model that can prove valuable in building real-life applications. Fine tuning Whisper unlocks its potential for unique use cases.
  • We can fine-tune smaller versions of Whisper for excellent performance, although the size and quality of the dataset are key here.
  • Fine-tuning ASR models demands significant computational power, typically in the range of 20-40 GB of VRAM for reasonable training times.

Conclusion

In this article, we covered the process of fine tuning Whisper on a custom Air Traffic Control dataset, starting with an overview of the dataset and concluding with a comparison of error rates and inference runtimes. Such fine-tuned models can be deployed to the cloud to provide end users with a real-time speech-to-text interface.

Let us know in the comment section if you are moving forward with this project and building an interesting application.


