Whisper is a leading open-source model used for converting speech to text. Developed by OpenAI, Whisper has been trained on a diverse array of languages and speech conditions using extensive data. Despite its advanced capabilities, Whisper sometimes struggles with transcribing new or unusual types of audio. To address this, it’s essential to fine-tune the model on specific datasets. In this article, we’ll guide you through the process of fine tuning Whisper using a custom dataset tailored for Automatic Speech Recognition in Air Traffic Control. This not only enhances transcription accuracy but also boosts safety in real-world operations.
By the end of this article, you’ll know how to adjust Whisper to work effectively with your own data. Additionally, we’ll explore how different versions of Whisper – Tiny, Base, and Small – perform in terms of processing speed and accuracy with practical examples from our tests.
We aim to make the smaller versions of Whisper performant on a specific audio dataset where the pretrained models perform poorly. In this article, we will explore the following topics in detail while fine tuning the Whisper Small model:
- Understanding the dataset – We will use the Air Traffic Control audio dataset for fine tuning Whisper.
- Going through the fine tuning code – We will start with the dataset preparation and eventually fine-tune the Whisper model on the custom dataset.
- Comparing error rates and inference time – We will use Word Error Rate (WER) as the benchmark metric and also compare the inference time for each of the fine tuned Whisper models.
Table of Contents
- The Air Traffic Control Dataset
- Whisper Models for Fine Tuning on the Air Traffic Control Dataset
- Fine Tuning Whisper on Custom Dataset for Air Traffic Control Audio
- Comparing Fine Tuned Whisper Tiny, Base, and Small Models
- Inference Time Comparison of the Fine Tuned Whisper Models
- Gradio UI for Whisper Fine Tuning
- Key Takeaways
- Conclusion
The Air Traffic Control Dataset
The atco2-asr-atcosim dataset will be used for fine tuning Whisper. It combines two different Air Traffic Control (ATC) operator speech datasets.
Following is a sample audio from the validation set.
As you may realize, these operator speech audios are extremely difficult to comprehend for the average listener. Although most of the conversations happen between the operator and the pilots, there are several reasons why an ASR (Automatic Speech Recognition) system for transcribing these audios can be beneficial.
- Increased safety and performance.
- Verifying whether instructions were clearly understood and followed.
- Training the future workforce.
Using pretrained Whisper models on these audio files for ASR does not produce good results. This calls for the fine tuning of the models for better results.
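If you want to see this for yourself before committing to training, a minimal sketch like the following runs the pretrained Whisper Small checkpoint on a single clip through the Hugging Face pipeline (the audio file path here is a placeholder; point it at any ATC recording you have):

from transformers import pipeline

# Baseline check with the pretrained checkpoint, before any fine tuning.
asr = pipeline(
    'automatic-speech-recognition',
    model='openai/whisper-small',
    device=0  # use device=-1 to run on the CPU
)

# 'sample_atc.wav' is a placeholder file name for an ATC audio clip.
print(asr('sample_atc.wav')['text'])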
Next let’s take a brief overview of the models that we will use for fine-tuning.
Whisper Models for Fine Tuning on the Air Traffic Control Dataset
We want the results to be not only accurate but also fast. For this reason, we will not use the Whisper Medium and Large models; rather, we will focus on three smaller variants: Whisper Tiny, Base, and Small.
The following table covers the details of the above three models.
Size | Parameters | English-only model | Multilingual model | Required VRAM | Relative speed |
tiny | 39 M | tiny.en | tiny | ~1 GB | ~32x |
base | 74 M | base.en | base | ~1 GB | ~16x |
small | 244 M | small.en | small | ~2 GB | ~6x |
As we can see, the required VRAM for each of the models is within acceptable limits, even for real-time deployment. In fact, the Tiny and Base models are quite fast on the CPU as well.
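If you are curious, the parameter counts are easy to verify once the transformers library is installed. The snippet below is only a sanity check and is not part of the fine tuning pipeline:

from transformers import WhisperForConditionalGeneration

tiny = WhisperForConditionalGeneration.from_pretrained('openai/whisper-tiny')
# Roughly matches the ~39 M figure in the table above (the exact count can
# differ slightly because of weight tying in the implementation).
print(f'{sum(p.numel() for p in tiny.parameters()) / 1e6:.1f}M parameters')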
Fine Tuning Whisper on Custom Dataset for Air Traffic Control Audio
From this section onward, we will focus on the training code for fine tuning Whisper on the custom dataset. Although we have three different models for fine-tuning, we will cover the code for the Whisper Small model only. All the Jupyter Notebooks, trained weights, and inference scripts are downloadable via the Download section.
Before we dive into the training code, the following are the libraries that we need:
- datasets
- transformers
- accelerate
- evaluate
- jiwer
- tensorboard
- gradio
We are building on top of the PyTorch framework. All the necessary dependencies will be installed via the first code cell of the Jupyter Notebook.
The code here is present in the fine_tune_whisper_small_atco2.ipynb notebook.
Let’s start with the installation of all the necessary components.
!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio
The Import Statements
Next, import the libraries and modules that we will need throughout the notebook.
from datasets import load_dataset, DatasetDict
from transformers import (
    WhisperTokenizer,
    WhisperProcessor,
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer
)
from datasets import Audio
from dataclasses import dataclass
from typing import Any, Dict, List, Union
import torch
import evaluate
Let’s cover some of the important ones in detail:
- datasets: To load the Air Traffic Control dataset.
- WhisperTokenizer: To tokenize the ground truth transcriptions of the audio files.
- WhisperFeatureExtractor: To extract the features from the audio arrays.
- WhisperProcessor: This class combines the Whisper Feature Extractor and the Whisper Tokenizer for easier access.
- WhisperForConditionalGeneration: To load the pretrained Whisper model.
- Seq2SeqTrainingArguments and Seq2SeqTrainer: To define the training arguments and the Trainer API respectively.
Defining Some Constants and Training Parameters
Now, let’s define some of the training parameters we will need later in the notebook.
model_id = 'openai/whisper-small'
out_dir = 'whisper_small_atco2'
epochs = 10
batch_size = 32
The model_id is the pretrained repository ID from Hugging Face. out_dir is the output directory where the training artifacts will be stored. We will train the model for 10 epochs with a batch size of 32.
Loading the Dataset, Tokenizer, and Feature Extractor
Dataset preparation is one of the most important aspects of starting the fine tuning process.
The first step is to load the training and validation sets from Hugging Face datasets.
atc_dataset_train = load_dataset('jlvdoorn/atco2-asr-atcosim', split='train')
atc_dataset_valid = load_dataset('jlvdoorn/atco2-asr-atcosim', split='validation')
If the above two datasets are printed, then we get the following output.
There are 8092 training samples and 2026 validation samples. Furthermore, there are three features: audio, text, and info.
Let’s check what one sample from the training set contains.
The audio feature contains the audio file path, the 1-D numeric array of the audio, and the sampling rate. The text feature is the ground truth transcription. The info feature contains any additional information; however, we won't need it during the fine tuning process.
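As an optional sanity check, we can inspect these fields directly. The snippet below simply mirrors the description above:

sample = atc_dataset_train[0]

print(sample['text'])                    # ground truth transcription
print(sample['audio']['sampling_rate'])  # expected to be 16000
print(len(sample['audio']['array']))     # number of raw audio samples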
The sampling rate is one of the most critical components to consider while fine tuning Whisper. All the Whisper models have been pretrained on audio sampled at 16 kHz. Here, the dataset audio files are already at 16 kHz. If that were not the case, we would need to resample them to the required rate.
Next, load the feature extractor, tokenizer, and the Whisper processor.
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)
tokenizer = WhisperTokenizer.from_pretrained(model_id, language='English', task='transcribe')
processor = WhisperProcessor.from_pretrained(model_id, language='English', task='transcribe')
When loading the tokenizer and processor, we define the target transcript language and the task. Since all the transcriptions are in English, we set the language to English and the task to transcribe.
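A quick, optional way to confirm that the tokenizer handles the transcriptions well is to encode and decode one ground truth string and check that the round trip is lossless:

input_ids = tokenizer(atc_dataset_train[0]['text']).input_ids
decoded = tokenizer.decode(input_ids, skip_special_tokens=True)

# Should print True if the tokenizer reproduces the original transcription.
print(decoded == atc_dataset_train[0]['text'])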
Preprocessing the Dataset
Preprocessing the dataset into the proper format is necessary before we can feed it to the training pipeline.
The foremost step here is handling the sampling rate of the audio files. As discussed earlier, Whisper models can only accept audio files with 16000 Hertz sampling rate. Although our dataset already contains samples with the same sampling rate, it is still a good idea to have a processing pipeline to ensure that.
atc_dataset_train = atc_dataset_train.cast_column('audio', Audio(sampling_rate=16000))
atc_dataset_valid = atc_dataset_valid.cast_column('audio', Audio(sampling_rate=16000))
We use the cast_column method to modify the audio feature. To modify the sampling rate, we use the Audio class that was imported earlier.
The next step involves extracting the audio features and tokenizing the transcriptions. We can write a helper function and then map the datasets for this.
def prepare_dataset(batch):
    audio = batch['audio']
    # Compute the log-Mel input features from the raw audio array.
    batch['input_features'] = feature_extractor(audio['array'], sampling_rate=audio['sampling_rate']).input_features[0]
    # Encode the ground truth transcription into label token IDs.
    batch['labels'] = tokenizer(batch['text']).input_ids
    return batch

atc_dataset_train = atc_dataset_train.map(
    prepare_dataset,
    num_proc=4
)
atc_dataset_valid = atc_dataset_valid.map(
    prepare_dataset,
    num_proc=4
)
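After mapping, each sample carries the log-Mel spectrogram and the tokenized labels. The following optional check prints their shapes (Whisper pads or truncates every clip to 30 seconds, which gives 80 Mel bins over 3000 frames):

import numpy as np

sample = atc_dataset_train[0]
print(np.array(sample['input_features']).shape)  # (80, 3000) log-Mel spectrogram
print(len(sample['labels']))                     # number of label token IDs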
The final step in the dataset processing is the data collator.
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # Pad the log-Mel input features into a uniform batch tensor.
        input_features = [{'input_features': feature['input_features']} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors='pt')

        # Pad the label sequences and replace padding positions with -100.
        label_features = [{'input_ids': feature['labels']} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors='pt')
        labels = labels_batch['input_ids'].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # If the decoder start token was prepended during tokenization, cut it
        # here; it is appended again during training.
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch['labels'] = labels
        return batch
The data collator performs the following steps:
- Converting the audio input features and the tokenized transcriptions to padded PyTorch tensors.
- Replacing the padded label tokens with -100 so that the loss function ignores them during loss calculation.
Loading the Whisper Model
Loading the pretrained Whisper model is straightforward.
model = WhisperForConditionalGeneration.from_pretrained(model_id)
model.generation_config.task = 'transcribe'
model.generation_config.forced_decoder_ids = None
Additionally, we need to create an instance of the data collator that will complete all the configuration steps before fine tuning.
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)
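Before moving on, we can optionally pass a couple of processed samples through the collator to confirm that batching works and that the labels are padded with -100:

batch = data_collator([atc_dataset_train[0], atc_dataset_train[1]])

print(batch['input_features'].shape)  # e.g. torch.Size([2, 80, 3000])
print(batch['labels'].shape)          # e.g. torch.Size([2, longest_label_length])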
Defining the Evaluation Metric
We will use WER (Word Error Rate) as the evaluation metric. This is one of the most common metrics for evaluation of speech recognition models.
metric = evaluate.load('wer')
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {'wer': wer}
You can refer to this Hugging Face space to learn more about WER.
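As a tiny illustration of how the metric behaves, a single deleted word against an eleven-word reference produces a WER of roughly 9% (the phrase below is made up purely for demonstration):

print(metric.compute(
    predictions=['lufthansa four five six climb flight level three four zero'],
    references=['lufthansa four five six climb to flight level three four zero']
))  # prints approximately 0.0909, i.e. 1 error over 11 reference words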
Defining Training Arguments, Trainer API, and Starting Training
Let’s define the training arguments with the necessary parameters.
training_args = Seq2SeqTrainingArguments(
    output_dir=out_dir,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=1,
    learning_rate=0.00001,
    warmup_steps=1000,
    bf16=True,
    fp16=False,
    num_train_epochs=epochs,
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    predict_with_generate=True,
    generation_max_length=225,
    report_to=['tensorboard'],
    load_best_model_at_end=True,
    metric_for_best_model='wer',
    greater_is_better=False,
    dataloader_num_workers=8,
    save_total_limit=2,
    lr_scheduler_type='constant',
    seed=42,
    data_seed=42
)
Above, bf16=True and fp16=False have been defined because the training experiment was run on an RTX GPU. If you run the training on a T4 or P100 GPU, be sure to use fp16=True instead, as those GPUs do not support the BF16 data type.
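If you are unsure which data type your GPU supports, PyTorch exposes a simple check that can guide the choice:

# Returns True on GPUs that support the BF16 data type (e.g. Ampere and newer).
print(torch.cuda.is_bf16_supported())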
The next code block contains the initialization of the Trainer API.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=atc_dataset_train,
    eval_dataset=atc_dataset_valid,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
Finally, we can call the train method of trainer to start the training process.
trainer.train()
Note: The Whisper Small fine-tuning was run on a 48 GB RTX A6000 GPU to accommodate the GPU memory needed for a batch size of 32. The Whisper Tiny and Base fine-tuning runs were carried out on RTX 4090 GPUs.
Following are the logs from the above training run.
We get the best WER of 3.15 on Epoch 7. The model saved from this epoch will be used for inference.
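Since load_best_model_at_end=True was set in the training arguments, the model held by the trainer after training is the best checkpoint. The downloadable notebooks take care of exporting it; a minimal sketch of how it could be written out (the best_model directory name simply mirrors the paths used by the inference scripts later in this article) looks like this:

# Persist the best checkpoint along with the processor for later inference.
trainer.save_model(f'{out_dir}/best_model')
processor.save_pretrained(f'{out_dir}/best_model')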
Comparing Fine Tuned Whisper Tiny, Base, and Small Models
From the different fine tuning experiments, we obtain the following table. This depicts:
- The best WER from each fine-tuning run.
- The ground truth and inference results of the first three samples from the validation dataset for each of the best Whisper models.
Needless to say, the Whisper Small model gives the best overall results, though it is still erroneous in some places. It is interesting to see that before fine tuning, the Whisper models produced garbage transcriptions, and the fine tuning process brought them extremely close to the ground truth. This shows that open-source AI models can be used to build real-life applications, provided the right strategy is followed.
Inference Time Comparison of the Fine Tuned Whisper Models
It is also important to note that for real-world deployment, the inference time matters, which in turn affects the deployment cost. The following script reads three audio files from the inference_data directory and computes the average inference time of each model. The following code is present in the compare_time.py script.
Note: The inference time comparison runs were performed on a laptop RTX 3070 Ti GPU.
"""
Script to compare time for fine-tuned Whisper models.
"""
import torch
import time
import os
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
model_dirs = [
    'whisper_tiny_atco2_v2/best_model',
    'whisper_base_atco2/best_model',
    'whisper_small_atco2/best_model'
]
input_dir = 'inference_data'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
for model_id in model_dirs:
    print(f"\nEvaluating model: {model_id}")

    # Load the fine tuned model and move it to the GPU (if available).
    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
        use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

    pipe = pipeline(
        'automatic-speech-recognition',
        model=model,
        tokenizer=processor.tokenizer,
        feature_extractor=processor.feature_extractor,
        torch_dtype=torch_dtype,
        device=device
    )

    total_time = 0
    num_runs = 0

    # Run each audio file 10 times and average the per-file inference time.
    for _ in range(10):
        for filename in os.listdir(input_dir):
            if filename.endswith('.wav'):
                start_time = time.time()
                result = pipe(os.path.join(input_dir, filename))
                end_time = time.time()
                total_time += (end_time - start_time)
                num_runs += 1

    average_time = total_time / num_runs
    print(f"\nAverage time taken for {model_id}: {average_time} seconds")
As the first run always takes more time (because of GPU warmup), we run each model 10 times through each of the three audio files.
The following graph shows the average inference time for each fine tuned Whisper model.
The fine tuned Whisper Tiny model has the lowest inference time. At the moment, this may not seem like much of an advantage, as it also has the highest WER among all the fine tuning runs. However, if we scale up the dataset, we can easily bring down the WER, which would make the Whisper Tiny model a great choice for real-time deployment.
Gradio UI for Whisper Fine Tuning
We have also prepared a simple Gradio demo for a smoother inference process using the fine tuned Whisper models. The code is in the gradio_ui.py script.
from transformers import pipeline
import gradio as gr
pipe = pipeline(
    model='whisper_small_atco2/best_model',
    tokenizer='whisper_small_atco2/best_model',
    task='automatic-speech-recognition',
    device='cuda'
)
def transcribe(audio):
    text = pipe(audio)['text']
    return text
iface = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(sources=['microphone', 'upload'], type='filepath'),
    outputs='text'
)
iface.launch(share=True)
The following demo shows uploading an audio file, submitting it, and getting the result in the output text box.
You can swap in the fine tuned Tiny or Base model as well, as shown below.
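For example, pointing the pipeline at the fine tuned Tiny checkpoint (the same directory name used in compare_time.py) only requires changing the paths:

pipe = pipeline(
    model='whisper_tiny_atco2_v2/best_model',
    tokenizer='whisper_tiny_atco2_v2/best_model',
    task='automatic-speech-recognition',
    device='cuda'
)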
Key Takeaways
- Whisper is an extremely powerful ASR model that can prove valuable in building real-life applications. Fine tuning Whisper unlocks its potential for unique use cases.
- We can fine-tune smaller versions of Whisper for excellent performance, although the size and quality of the dataset are key here.
- Fine-tuning ASR models demands high computational power, mostly in the range of 20-40 GB of VRAM for optimal training time.
Conclusion
In this article, we covered the process of fine tuning Whisper on a custom Air Traffic Control dataset. We started with an explanation of the dataset and concluded with a comparison of error rates and inference runtimes. Such fine-tuned models can be integrated into a cloud deployment to provide end users with a real-time speech-to-text interface.
Let us know in the comment section if you are moving forward with this project and building an interesting application.