Quantize and speed up any LLM
This tutorial demonstrates how to use the pruna package to optimize both the latency and the memory footprint of any LLM from the transformers package.
We use the meta-llama/Llama-3.2-1b-Instruct model as an example, but this tutorial works with any language model.
We show results with the hqq quantizer here, but this tutorial also works with the gptq, llm_int8, and higgs quantizers (higgs requires pruna_pro).
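All of these quantizers store weights with fewer bits. To build intuition for what a 4-bit setting means, here is a simplified group-wise round-to-nearest sketch in plain Python. It is not HQQ: HQQ additionally refines the zero-point with a half-quadratic solver instead of simply using the group minimum, and real kernels pack the integers into tensors. The function names and the group size of 4 are illustrative assumptions.

```python
def quantize_group(weights, bits):
    """Asymmetric round-to-nearest quantization of one group of floats."""
    levels = 2 ** bits - 1                      # 15 integer steps for 4-bit
    w_min, w_max = min(weights), max(weights)
    scale = (w_max - w_min) / levels or 1.0     # fall back to 1.0 for constant groups
    zero = -w_min / scale                       # float zero-point: maps w_min to 0
    q = [min(levels, max(0, round(w / scale + zero))) for w in weights]
    return q, scale, zero


def dequantize_group(q, scale, zero):
    """Map the integers back to approximate float weights."""
    return [(v - zero) * scale for v in q]


def quantize(weights, bits=4, group_size=4):
    """Split a flat weight list into groups, each with its own scale and zero-point."""
    return [quantize_group(weights[i:i + group_size], bits)
            for i in range(0, len(weights), group_size)]


def dequantize(groups):
    out = []
    for q, scale, zero in groups:
        out.extend(dequantize_group(q, scale, zero))
    return out


weights = [0.12, -0.5, 0.33, 0.9, -1.2, 0.05, 0.4, -0.07]
restored = dequantize(quantize(weights, bits=4, group_size=4))
# Worst-case reconstruction error; it is at most half of each group's scale
print(max(abs(w - r) for w, r in zip(weights, restored)))
```

Fewer bits mean fewer integer levels and therefore a larger `scale` per group, which is the quality trade-off each quantizer tries to manage.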
1. Loading the LLM
First, load your LLM and its associated tokenizer.
[ ]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "meta-llama/Llama-3.2-1b-Instruct"
# We observed better performance with bfloat16 precision.
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="cuda",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
2. Testing the Original Model Speed
[ ]:
import time
# Warm up the model
for _ in range(3):
    with torch.no_grad():
        inp = tokenizer(["This is a test of this large language model"], return_tensors="pt")
        input_ids = inp['input_ids'].cuda()
        # Setting min_length == max_length forces exactly 56 new tokens, so runs are comparable
        generated_ids = model.generate(input_ids, max_length=input_ids.shape[1] + 56, min_length=input_ids.shape[1] + 56)
        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
torch.cuda.synchronize()
# Time a single generation
t = time.time()
with torch.no_grad():
    inp = tokenizer(["This is a test of this large language model"], return_tensors="pt")
    input_ids = inp['input_ids'].cuda()
    generated_ids = model.generate(input_ids, max_length=input_ids.shape[1] + 56, min_length=input_ids.shape[1] + 56)
    text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(text)
torch.cuda.synchronize()
print(time.time() - t)
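A single timed run is noisy. Below is a minimal sketch of a reusable timing helper that warms up, repeats the measurement, and reports the median. `benchmark` is a hypothetical helper, not part of pruna, and it assumes you pass `torch.cuda.synchronize` as `sync` so that queued GPU work is counted before the clock stops.

```python
import statistics
import time


def benchmark(fn, warmup=3, runs=5, sync=None):
    """Call fn() `warmup` times, then time `runs` calls and return the median in seconds."""
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warmup work has finished before timing
    timings = []
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()  # wait for queued GPU work before stopping the clock
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


# Hypothetical usage with the notebook variables defined above:
# def run():
#     with torch.no_grad():
#         model.generate(input_ids, max_length=input_ids.shape[1] + 56,
#                        min_length=input_ids.shape[1] + 56)
# print(benchmark(run, sync=torch.cuda.synchronize))
```

The median is less sensitive than the mean to an occasional slow run (for example, one that triggers recompilation).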
3. Initializing the Smash Config
Next, initialize the smash_config. Here we use the hqq quantizer and the torch_compile compiler.
[ ]:
from pruna import SmashConfig
smash_config = SmashConfig()
# Select the quantizer
smash_config['quantizer'] = 'hqq'
smash_config['hqq_weight_bits'] = 4  # 2 and 8 bits also work, but 4 gives the best performance
smash_config['hqq_compute_dtype'] = 'torch.bfloat16'  # float16 also works, but bfloat16 performs better
# Select torch_compile for the compilation
smash_config['compiler'] = 'torch_compile'
# smash_config['torch_compile_max_kv_cache_size'] = 400  # uncomment to use a custom KV cache size
smash_config['torch_compile_fullgraph'] = True
smash_config['torch_compile_mode'] = 'max-autotune'
# If the model is not compatible with CUDA graphs, comment out the line above and uncomment the line below
# smash_config['torch_compile_mode'] = 'max-autotune-no-cudagraphs'
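The `hqq_weight_bits` setting drives most of the memory savings. Here is a back-of-the-envelope sketch of weight storage. It assumes that every weight is quantized and that each group of `group_size` weights stores 32 bits of metadata (for example, a 16-bit scale and a 16-bit zero-point). The group size of 64, the metadata size and the round parameter count are illustrative assumptions, not pruna defaults. The estimate ignores activations, the KV cache and any layers left in bfloat16 (quantizers typically target linear layers).

```python
def weight_memory_gib(num_params, bits, group_size=None, meta_bits_per_group=0):
    """Rough size of the weights alone, in GiB."""
    total_bits = num_params * bits
    if group_size:
        num_groups = -(-num_params // group_size)  # ceiling division
        total_bits += num_groups * meta_bits_per_group
    return total_bits / 8 / 1024**3


# Illustrative round number, not the exact parameter count of Llama-3.2-1B
num_params = 1_000_000_000
bf16 = weight_memory_gib(num_params, 16)
int4 = weight_memory_gib(num_params, 4, group_size=64, meta_bits_per_group=32)
print(f"bfloat16: {bf16:.2f} GiB, 4-bit: {int4:.2f} GiB, ratio: {bf16 / int4:.2f}x")
```

Under these assumptions, this prints about 1.86 GiB versus 0.52 GiB (3.56x). The per-group metadata is why 4-bit weights cost 4.5 bits each here instead of 4.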
4. Smashing the Model
Now, smash the model. This can take up to 30 seconds.
[ ]:
from pruna import smash
# Smash the model; `pipe` is the quantized and compiled model
pipe = smash(
    model=model,
    smash_config=smash_config,
)
5. Running the Model
Finally, run the model to generate text. Note that a short warmup (< 1 minute) is needed the first time the model runs.
NB: Currently, the quantized and compiled LLM only supports the default sampling strategy, and tokens must be generated with model.generate(input_ids, max_new_tokens=X), where X is the number of tokens you want to produce. We plan to support other decoding strategies (DoLa, contrastive search, etc.) in the near future.
[ ]:
import time
# Warm up the smashed model (the first calls trigger compilation)
for _ in range(3):
    with torch.no_grad():
        inp = tokenizer(["This is a test of this large language model"], return_tensors="pt")
        input_ids = inp['input_ids'].cuda()
        # Use the smashed model returned by smash()
        generated_ids = pipe.generate(input_ids, max_new_tokens=56)
        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
torch.cuda.synchronize()
# Time a single generation
t = time.time()
with torch.no_grad():
    inp = tokenizer(["This is a test of this large language model"], return_tensors="pt")
    input_ids = inp['input_ids'].cuda()
    generated_ids = pipe.generate(input_ids, max_new_tokens=56)
    text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
print(text)
torch.cuda.synchronize()
print(time.time() - t)
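To turn the printed timings into a comparison, a few small helpers can compute throughput and speedup. This is a sketch. The helper names and the `baseline_s`/`smashed_s` variables are hypothetical. The sketch assumes that `generate` returns the prompt tokens followed by the new tokens, as it does for decoder-only models in transformers. Because `max_new_tokens` is an upper bound and generation can stop early at an end-of-sequence token, counting the actual new tokens is safer than assuming 56.

```python
def new_token_count(output_len: int, prompt_len: int) -> int:
    """Number of generated tokens, given that the output starts with the prompt."""
    if output_len < prompt_len:
        raise ValueError("output is shorter than the prompt")
    return output_len - prompt_len


def tokens_per_second(num_new_tokens: int, seconds: float) -> float:
    if seconds <= 0:
        raise ValueError("elapsed time must be positive")
    return num_new_tokens / seconds


def speedup(baseline_seconds: float, optimized_seconds: float) -> float:
    """How many times faster the optimized run is than the baseline."""
    if optimized_seconds <= 0:
        raise ValueError("elapsed time must be positive")
    return baseline_seconds / optimized_seconds


# Hypothetical usage with the notebook variables (timings copied from the printed output):
# n = new_token_count(generated_ids.shape[1], input_ids.shape[1])
# print(tokens_per_second(n, smashed_s), speedup(baseline_s, smashed_s))
```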
Wrap Up
Congratulations! You’ve optimized your LLM using HQQ quantization and torch.compile!
The quantized model uses less memory and runs faster while maintaining good quality.
You can also try the other quantizers available in pruna (gptq, llm_int8), or the higgs quantizer from pruna_pro, which also speeds up batch inference and can maintain quality at low bit widths.
Want more optimization techniques? Check out our other tutorials!