Accelerating inference in vLLM serving (Pro)
This tutorial demonstrates how to use the pruna_pro package to optimize any LLM and plug it into vLLM for fast serving.
1. Loading the LLM
First, load your language model and its tokenizer.
[ ]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_path = "meta-llama/Llama-3.1-8B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
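Optionally, you can run a quick baseline generation with the unoptimized model. This is not part of the recipe itself, just a minimal sanity check (assuming a GPU setup via device_map="auto") so you can compare quality and speed after smashing.
[ ]:
# Optional baseline check before optimization (not required for the rest of the tutorial)
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
with torch.no_grad():
    baseline_ids = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(baseline_ids[0], skip_special_tokens=True))  # noqa: T201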
2. Initializing the Smash Config
Next, initialize the smash_config.
[ ]:
from pruna_pro import SmashConfig, smash
# Initialize the SmashConfig
smash_config = SmashConfig()
# Select the quantizer
smash_config['quantizer'] = 'hqq'
smash_config['hqq_weight_bits'] = 4
smash_config['hqq_compute_dtype'] = 'torch.bfloat16' # can work with float16, but better performance with bfloat16
smash_config['hqq_use_torchao_kernels'] = False # set to False to enable saving
smash_config['hqq_force_hf_implementation'] = True # set to True to bypass the AutoHQQHFModel quantization
3. Smashing the Model
Now, you can smash the model, which can take up to 2 minutes. Don’t forget to replace the token with the one provided by PrunaAI.
[ ]:
model = smash(
model=model,
token="<your_pruna_token>",
smash_config=smash_config,
)
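As an optional smoke test before saving, you can rerun the same prompt on the smashed model and compare it with the baseline output. This sketch assumes the smashed model still exposes the standard transformers interface (generate, device); the exact wrapper behavior may vary across pruna_pro versions.
[ ]:
# Optional smoke test on the smashed model (assumes the transformers generate API is preserved)
inputs = tokenizer("Hello, how are you?", return_tensors="pt").to(model.device)
with torch.no_grad():
    smashed_ids = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(smashed_ids[0], skip_special_tokens=True))  # noqa: T201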
4. Saving the model
vLLM cannot (yet) accept already-loaded models, so you have to save your pruna-optimized model to a local directory and then give that path to vLLM.
[ ]:
smash_config.save_fns = []  # remove pruna's save functions; vLLM expects a model saved in the standard HF format (loadable via from_pretrained)
model.save_pretrained("path/to/pruna/model")
tokenizer.save_pretrained("path/to/pruna/model")
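If you want to double-check what was exported (purely optional), you can list the saved directory. The exact filenames depend on the model, but you should see the usual HF checkpoint layout (config.json, tokenizer files, *.safetensors shards).
[ ]:
import os

# Inspect the exported checkpoint directory (optional)
for fname in sorted(os.listdir("path/to/pruna/model")):
    print(fname)  # noqa: T201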
5. Loading your optimized model with vLLM
Assuming you have pruna_pro installed in your environment (pip install pruna_pro), you basically have only one step to add before serving your pruna-optimized LLM with vLLM: change the “quant_method” name from “hqq” to “hqq_pruna_torchaoint4” (recognized by and compatible with vLLM).
[ ]:
from pruna_pro.algorithms.quantization.utils.pruna_to_vllm import patch_quantization_name
patch_quantization_name("path/to/pruna/model", "hqq_pruna_torchaoint4")
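For reference, here is a rough, hypothetical sketch of what this patch amounts to, assuming the quantization metadata lives under quantization_config in the saved config.json. The helper above already handles this, so you do not need to run it yourself.
[ ]:
import json
import os

# Hypothetical manual equivalent of patch_quantization_name: rewrite quant_method in config.json
config_path = os.path.join("path/to/pruna/model", "config.json")
with open(config_path) as f:
    config = json.load(f)
config.setdefault("quantization_config", {})["quant_method"] = "hqq_pruna_torchaoint4"
with open(config_path, "w") as f:
    json.dump(config, f, indent=2)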
And that’s it! Now you can load your pruna-optimized LLM in vLLM, just as you usually would.
[ ]:
from vllm import LLM, SamplingParams # noqa: I001
llm = LLM(model="path/to/pruna/model", quantization="hqq_pruna_torchaoint4")
6. Run your (v)LLM
[ ]:
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, min_tokens=100, max_tokens=100)
prompt = "Hello, how are you?"
outputs = llm.generate(prompt, sampling_params)
[ ]:
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")  # noqa: T201
Wrap Up
Congratulations! You can now serve LLMs that benefit from both vLLM and pruna optimizations.
Do you need more control over your LLM serving engine? Check out the other pruna-compatible engines: :doc:`Tritonserver+pruna </setup/tritonserver.rst>` and :doc:`LitServe+pruna </setup/litserve.rst>`.