Transcribe 2 hours of audio in less than 2 minutes with Whisper
This tutorial demonstrates how to use the pruna package to optimize any custom Whisper model. Here, the smash function wraps the model in an efficient pipeline that transcribes 2 hours of audio in under 2 minutes on an A100 GPU. We will use the openai/whisper-large-v3 model as an example.
1. Loading the ASR model
First, load your ASR model.
[ ]:
import torch
from transformers import AutoModelForSpeechSeq2Seq

# Use the GPU with half precision when available; otherwise fall back to CPU and float32
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, use_safetensors=True, low_cpu_mem_usage=True,
)
model.to(device)
2. Initializing the Smash Config
Next, initialize the smash_config. Since the compilers require a processor, we add it to the smash_config.
[ ]:
from pruna import SmashConfig
# Initialize the SmashConfig
smash_config = SmashConfig()
smash_config.add_processor(model_id)
smash_config['compilers'] = ['ws2t', 'cwhisper']
# uncomment the following line to quantize the model to 8 bits
# smash_config['comp_cwhisper_weight_bits'] = 8
3. Smashing the Model
Now, smash the model. This will take approximately 2 minutes on a T4 GPU. Don’t forget to replace the token with your Pruna token.
[ ]:
from pruna import smash
# Smash the model
smashed_model = smash(
    model=model,
    token='<your_token>',  # replace <your_token> with your actual token, or set to None if you do not have one yet
    smash_config=smash_config,
)
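Rather than pasting the token into the notebook, one option is to read it from an environment variable. The sketch below is only an illustration: the variable name PRUNA_TOKEN and the helper load_token are hypothetical, not something pruna defines. It returns None when the variable is unset, which matches the "set to None if you do not have one yet" case above.

```python
import os

# Hypothetical variable name; choose whatever fits your environment.
TOKEN_ENV_VAR = "PRUNA_TOKEN"


def load_token(env_var: str = TOKEN_ENV_VAR):
    """Return the token from the environment, or None if it is unset or empty."""
    token = os.environ.get(env_var, "").strip()
    return token or None


pruna_token = load_token()
```

The result can then be passed as token=pruna_token in the smash call, which keeps the secret out of the saved notebook.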
4. Preparing the Input
[ ]:
import requests
response = requests.get("https://huggingface.co/datasets/reach-vb/random-audios/resolve/main/sam_altman_lex_podcast_367.flac")
response.raise_for_status()  # stop early if the download failed
audio_sample = 'sam_altman_lex_podcast_367.flac'
# Save the content to the specified file
with open(audio_sample, 'wb') as f:
    f.write(response.content)
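The cell above holds the whole file in memory before writing it. For long recordings, a streamed download may be preferable. The following is a minimal sketch that swaps requests for the standard library's urllib so it has no extra dependencies; the helper name download_file and the 1 MiB chunk size are arbitrary choices for illustration.

```python
import shutil
import urllib.request
from pathlib import Path


def download_file(url: str, destination: str, chunk_size: int = 1 << 20) -> Path:
    """Stream `url` to `destination` in chunks and return the destination path."""
    target = Path(destination)
    # urlopen raises urllib.error.HTTPError or URLError if the request fails.
    with urllib.request.urlopen(url) as response, open(target, "wb") as f:
        # Copy chunk by chunk so the full file is never held in memory.
        shutil.copyfileobj(response, f, length=chunk_size)
    return target
```

Calling download_file with the URL from the cell above and 'sam_altman_lex_podcast_367.flac' as the destination should produce the same file as the requests version.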
5. Running the Model
Finally, run the model to transcribe the audio file. Make sure you have ffmpeg installed.
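A quick way to confirm that ffmpeg is reachable from the notebook is to look for it on the PATH. This is a small sketch using only the standard library; the helper name ffmpeg_available is made up for illustration.

```python
import shutil


def ffmpeg_available() -> bool:
    """Return True if an `ffmpeg` executable is found on the PATH."""
    return shutil.which("ffmpeg") is not None


if not ffmpeg_available():
    print("ffmpeg not found on PATH; install it before running the transcription.")
```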
[ ]:
# Display the result
smashed_model(audio_sample)
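To check the speed claim from the introduction on your own hardware, you can time the call and compare the elapsed time with the audio duration. The helpers below are a sketch with made-up names (time_call, speedup_over_real_time). The audio duration has to be supplied by you, since this tutorial does not show how to read it from the file.

```python
import time


def time_call(fn, *args, **kwargs):
    """Call `fn` and return (result, elapsed_seconds) using a monotonic clock."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def speedup_over_real_time(audio_seconds: float, elapsed_seconds: float) -> float:
    """Seconds of audio transcribed per second of wall-clock time."""
    return audio_seconds / elapsed_seconds


# The headline figure as arithmetic: 2 hours of audio in 2 minutes is 60x real time.
print(speedup_over_real_time(2 * 60 * 60, 2 * 60))  # prints 60.0
```

For example, result, elapsed = time_call(smashed_model, audio_sample) followed by speedup_over_real_time(audio_seconds, elapsed) gives the measured factor. The first call may include one-off warm-up costs, so a second run may be more representative.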
Wrap Up
Congratulations! You have successfully smashed an ASR model. You can now use the pruna package to optimize any custom ASR model. The only parts you should need to modify are steps 1, 4, and 5 to fit your use case.