100% faster Whisper Transcription


This tutorial demonstrates how to use the pruna package to optimize any custom Whisper model, using the openai/whisper-large-v3 model as an example.

1. Loading the ASR model

First, load your automatic speech recognition (ASR) model.

import torch
from transformers import AutoModelForSpeechSeq2Seq


# Use the first GPU in float16 when available; otherwise fall back to the CPU in float32
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v3"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
)
model.to(device)

2. Initializing the Smash Config

Next, initialize the smash_config and select the cwhisper compiler. Because this compiler requires the model's processor, add it to the smash_config with add_processor.

from pruna import SmashConfig

# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config.add_processor(model_id)  # the cwhisper compiler requires the processor
smash_config['compilers'] = 'cwhisper'
# Optional: uncomment the following line to quantize the weights to 8 bits
# smash_config['comp_cwhisper_weight_bits'] = 8
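
To see what the optional 8-bit setting can save, here is a back-of-the-envelope sketch of weight memory at each precision. It assumes Whisper large has roughly 1.55 billion parameters (the size OpenAI lists for it) and that 8-bit weights take one byte each; activations, caches, and runtime overhead are ignored.

```python
# Rough weight-memory estimate for Whisper large at different precisions.
# Assumptions: ~1.55 billion parameters, 1 byte per 8-bit weight,
# and no overhead beyond the raw weights.
NUM_PARAMS = 1_550_000_000

BYTES_PER_PARAM = {
    "float32": 4,  # CPU dtype chosen in step 1
    "float16": 2,  # GPU dtype chosen in step 1
    "int8": 1,     # assumed storage for comp_cwhisper_weight_bits = 8
}


def weight_memory_gb(num_params: int, bytes_per_param: int) -> float:
    """Return the size of the weights alone in gigabytes (10**9 bytes)."""
    return num_params * bytes_per_param / 1e9


for dtype, nbytes in BYTES_PER_PARAM.items():
    print(f"{dtype:>8}: {weight_memory_gb(NUM_PARAMS, nbytes):.2f} GB")
```

This prints 6.20 GB for float32, 3.10 GB for float16, and 1.55 GB for int8: the float16 default on GPU already halves the float32 footprint, and 8-bit weights would halve it again. Actual savings depend on how the compiler stores the weights.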

3. Smashing the Model

Now you can smash the model, which takes approximately 2 minutes on a T4 GPU. Don’t forget to replace the token with your Pruna token.

from pruna import smash

# Smash the model
smashed_model = smash(
    model=model,
    token='<your_token>',  # replace <your_token> with your actual token, or set it to None if you do not have one yet
    smash_config=smash_config,
)
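
Rather than pasting the token into the notebook, you can read it from an environment variable and fall back to None when it is missing, as the comment in the cell suggests. This is a minimal sketch; the variable name PRUNA_TOKEN and the helper resolve_pruna_token are hypothetical choices, not something pruna itself looks up.

```python
import os
from typing import Optional


def resolve_pruna_token(env_var: str = "PRUNA_TOKEN") -> Optional[str]:
    """Return the token stored in `env_var`, or None if it is unset or blank.

    PRUNA_TOKEN is a hypothetical variable name used only for this sketch.
    """
    token = os.environ.get(env_var, "").strip()
    return token or None
```

You would then call `smash(model=model, token=resolve_pruna_token(), smash_config=smash_config)`, which keeps the token out of the notebook and passes None when no token is configured.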

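4. Preparing the Input

Load an audio sample and convert it into log-mel input features with the model's processor, on the same device and in the same precision as the model.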

from datasets import load_dataset
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(model_id)

dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
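# Convert the waveform to log-mel features and match the model's device and dtype from step 1
input_features = processor(
    sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
).input_features.to(device, dtype=torch_dtype)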
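
It helps to know what input_features contains. Assuming the Hugging Face WhisperFeatureExtractor defaults for whisper-large-v3 (16 kHz audio, 30-second windows, hop length 160, 128 mel bins), each window becomes a 128 × 3000 log-mel spectrogram, and longer audio is truncated to its first 30 seconds by default. The sketch below reproduces that shape arithmetic and adds a hypothetical helper, split_into_windows, for cutting longer recordings into 30-second chunks; it is plain Python, not part of transformers or pruna.

```python
# Shape arithmetic behind Whisper's input_features (log-mel spectrogram).
# Assumed feature-extractor defaults for whisper-large-v3:
SAMPLING_RATE = 16_000  # Hz expected by the model
CHUNK_SECONDS = 30      # length of one input window
HOP_LENGTH = 160        # samples between consecutive spectrogram frames
N_MELS = 128            # mel bins used by large-v3


def expected_feature_shape(batch_size: int = 1) -> tuple:
    """Return (batch, mel bins, frames) for one padded 30 s window per item."""
    samples_per_window = SAMPLING_RATE * CHUNK_SECONDS  # 480,000 samples
    frames = samples_per_window // HOP_LENGTH           # 3,000 frames
    return (batch_size, N_MELS, frames)


def split_into_windows(waveform, sampling_rate: int = SAMPLING_RATE):
    """Split a 1-D waveform into consecutive 30 s chunks (the last may be shorter).

    Hypothetical helper: raises if the audio is not already at 16 kHz.
    """
    if sampling_rate != SAMPLING_RATE:
        raise ValueError(
            f"expected {SAMPLING_RATE} Hz audio, got {sampling_rate} Hz; resample first"
        )
    window = SAMPLING_RATE * CHUNK_SECONDS
    return [waveform[start:start + window] for start in range(0, len(waveform), window)]
```

For example, `split_into_windows(sample["array"], sample["sampling_rate"])` yields chunks you could pass through the processor one at a time; slicing works the same way on the NumPy array the dataset returns.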

5. Running the Model

Finally, run the smashed model to transcribe the audio sample and decode the predicted token IDs into text.

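# Transcribe the audio with the smashed model
results = smashed_model(input_features)
# Decode the token IDs; set skip_special_tokens=True to drop markers such as <|startoftranscript|>
processor.decode(results, skip_special_tokens=False)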
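
With skip_special_tokens=False, the decoded string keeps Whisper's control tokens, such as <|startoftranscript|>, the language tag, and <|endoftext|>. Passing skip_special_tokens=True is the simplest way to get plain text; if you already have the raw string, the regex sketch below approximates the same cleanup. It assumes every special token has the <|...|> form and is a text post-processing step, not what the tokenizer does internally.

```python
import re

# Whisper special tokens look like <|startoftranscript|>, <|en|>, <|transcribe|>,
# <|notimestamps|>, <|endoftext|>, or timestamps such as <|0.00|>.
SPECIAL_TOKEN_PATTERN = re.compile(r"<\|[^|>]*\|>")


def strip_special_tokens(decoded: str) -> str:
    """Remove <|...|> markers from an already-decoded string and tidy whitespace.

    Approximates decoding with skip_special_tokens=True, which instead filters
    token IDs before detokenizing.
    """
    without_tokens = SPECIAL_TOKEN_PATTERN.sub("", decoded)
    return " ".join(without_tokens.split())
```

For example, `strip_special_tokens(processor.decode(results, skip_special_tokens=False))` returns only the transcript text.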

Wrap Up

Congratulations! You have successfully smashed an ASR model. You can now use the pruna package to optimize any custom ASR model; only steps 1, 4, and 5 need to change to fit your use case.
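
The "100% faster" in the title corresponds to a 2× speedup, that is, half the latency of the original model. Gains depend on your GPU, audio length, and settings, so you may want to measure them on your own setup. Below is a minimal timing sketch using only the standard library: it reports the median wall-clock latency after a few warm-up calls. The helper names median_latency and speedup are illustrative. On GPU, CUDA kernels run asynchronously, so if the call you time can return before the work finishes, call torch.cuda.synchronize() inside the timed function.

```python
import statistics
import time
from typing import Any, Callable


def median_latency(fn: Callable[..., Any], *args: Any, warmup: int = 2, runs: int = 5) -> float:
    """Return the median wall-clock time in seconds of fn(*args) over `runs` calls."""
    for _ in range(warmup):  # warm-up calls absorb one-off costs such as caching
        fn(*args)
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(*args)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def speedup(baseline_seconds: float, optimized_seconds: float) -> float:
    """Ratio of baseline to optimized latency; 2.0 corresponds to '100% faster'."""
    return baseline_seconds / optimized_seconds
```

For example, `speedup(median_latency(model.generate, input_features), median_latency(smashed_model, input_features))` compares the original model's `generate` with the smashed model on the input from step 4.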