Smashing Automatic Speech Recognition Models with C
===================================================

This tutorial demonstrates how to use the ``pruna`` package to optimize a custom Whisper model. The output is an optimized version of the same Whisper model. We will use the ``openai/whisper-large-v3`` model as an example.

Loading the ASR Model
---------------------

First, load your ASR model and its processor.

.. code-block:: python

    import torch
    from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
    from datasets import load_dataset

    # Use the GPU with half precision when available, otherwise fall back to CPU
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    model_id = "openai/whisper-large-v3"

    model = AutoModelForSpeechSeq2Seq.from_pretrained(
        model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True, use_safetensors=True
    )
    model.to(device)

    processor = AutoProcessor.from_pretrained(model_id)

Initializing the Smash Config
-----------------------------

Next, initialize the ``smash_config``.

.. code-block:: python

    from pruna_engine.SmashConfig import SmashConfig

    # Initialize the SmashConfig
    smash_config = SmashConfig()
    smash_config['task'] = 'audio_text_transcription'
    smash_config['compilers'] = 'c_whisper'
    smash_config['processor'] = processor
    # Uncomment the following line to quantize the model to 8 bits
    # smash_config['weight_quantization_bits'] = 8

Smashing the Model
------------------

Now, smash the model.

.. code-block:: python

    from pruna.smash import smash

    # Smash the model
    smashed_model = smash(
        model=model,
        api_key='',  # replace with your actual API key
        smash_config=smash_config,
    )

Remember to replace ``api_key`` with the key provided by PrunaAI.

Preparing the Input
-------------------

Then, load an audio sample and convert it into input features with the processor.

.. code-block:: python

    dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
    sample = dataset[0]["audio"]

    # Convert the raw waveform to input features on the same device and dtype as the model
    input_features = processor(
        sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt"
    ).input_features.to(device, dtype=torch_dtype)

    # Decoder prompt ids for English transcription
    prompt = processor.get_decoder_prompt_ids(language="english", task="transcribe")

Running the Model
-----------------

Finally, run the model to transcribe the audio sample.

.. code-block:: python

    # Transcribe the audio
    results = smashed_model(input_features)

    # Decode the generated tokens and display the result
    transcription = processor.decode(results, skip_special_tokens=False)
    print(transcription)

Wrap Up
-------

Congratulations! You have successfully smashed an ASR model. You can now use the ``pruna`` package to optimize any custom ASR model. The only parts you should need to modify are step 1 (loading the model) and step 5 (running the model) to fit your use case.
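To check whether smashing actually pays off for your own model, it can help to time the original and the smashed model on the same input. The sketch below is not part of ``pruna``; ``time_callable`` and ``speedup`` are hypothetical helper names introduced here for illustration, built only on the Python standard library.

```python
import statistics
import time


def time_callable(fn, *args, warmup=1, repeats=5, **kwargs):
    """Return the median wall-clock time in seconds of fn(*args, **kwargs).

    Hypothetical helper, not part of pruna. On a GPU, calls may return before
    the work has finished, so fn should synchronize (for example with
    torch.cuda.synchronize()) before returning to get meaningful timings.
    """
    # Warm-up calls are excluded so one-time costs do not skew the result.
    for _ in range(warmup):
        fn(*args, **kwargs)

    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args, **kwargs)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def speedup(baseline_seconds, optimized_seconds):
    """Return the ratio baseline / optimized; values above 1 mean faster."""
    return baseline_seconds / optimized_seconds
```

Assuming the names from this tutorial, one possible use is ``speedup(time_callable(model.generate, input_features), time_callable(smashed_model, input_features))``. The median is used rather than the mean so that a single slow run has less influence on the reported number.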