Accelerating inference in vLLM serving (Pro)


This tutorial demonstrates how to use the pruna_pro package to optimize any LLM and plug it into vLLM for fast serving.

1. Loading the LLM

First, load your language model and its tokenizer.

[ ]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "meta-llama/Llama-3.1-8B-Instruct"

model = AutoModelForCausalLM.from_pretrained(model_path,
                                             torch_dtype=torch.bfloat16,
                                             device_map="auto",
                                             low_cpu_mem_usage=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
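For a sense of scale, the back-of-the-envelope arithmetic below compares the weight memory of the bfloat16 model loaded above with a 4-bit version. It is a rough sketch under stated assumptions. The parameter count of about 8.03 billion for Llama-3.1-8B is approximate. The per-group overhead of one 16-bit scale and one 16-bit zero point per 64 weights is an illustrative assumption, not pruna's exact storage format. Activations, the KV cache, and any layers left unquantized are ignored.

```python
# Rough weight-memory estimate (assumption: ~8.03 billion parameters).
NUM_PARAMS = 8.03e9


def weight_memory_gb(num_params: float, bits_per_weight: float) -> float:
    """Return the memory needed to store the weights, in gigabytes (1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


# bfloat16 stores every weight in 16 bits.
bf16_gb = weight_memory_gb(NUM_PARAMS, 16)

# Assumption: 4-bit codes plus one 16-bit scale and one 16-bit zero point
# per group of 64 weights, i.e. 4 + 32/64 = 4.5 bits per weight on average.
int4_gb = weight_memory_gb(NUM_PARAMS, 4 + 32 / 64)

print(f"bfloat16 weights: ~{bf16_gb:.1f} GB")
print(f"4-bit weights:    ~{int4_gb:.1f} GB")
```

Under these assumptions, the 4-bit weights need roughly 3.5 times less memory than the bfloat16 ones.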

2. Initializing the Smash Config

Next, initialize the smash_config and select the HQQ quantizer with 4-bit weights.

[ ]:
from pruna_pro import SmashConfig, smash

# Initialize the SmashConfig
smash_config = SmashConfig()
# Select the quantizer
smash_config['quantizer'] = 'hqq'
smash_config['hqq_weight_bits'] = 4
smash_config['hqq_compute_dtype'] = 'torch.bfloat16'  # float16 also works, but bfloat16 gives better performance
smash_config['hqq_use_torchao_kernels'] = False  # set to False to enable saving
smash_config['hqq_force_hf_implementation'] = True  # set to True to bypass the AutoHQQHFModel quantization
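The sketch below illustrates what `hqq_weight_bits = 4` means. It applies plain group-wise affine (asymmetric) round-to-nearest quantization to a toy list of weights, mapping each group to integer codes 0 to 15 plus a scale and a zero point. This is a simplified stand-in and not pruna's or HQQ's implementation. HQQ additionally optimizes the quantization parameters with a half-quadratic solver, and the group size used here is a hypothetical choice.

```python
def quantize_group(values, bits=4):
    """Affine round-to-nearest quantization of one group of weights.

    Returns integer codes in [0, 2**bits - 1] plus the scale and zero point
    needed to dequantize them.
    """
    qmax = 2 ** bits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / qmax if hi > lo else 1.0
    zero = -lo / scale  # zero point, kept as a float here for simplicity
    codes = [min(qmax, max(0, round(v / scale + zero))) for v in values]
    return codes, scale, zero


def dequantize_group(codes, scale, zero):
    """Map integer codes back to approximate float weights."""
    return [(c - zero) * scale for c in codes]


def quantize_groupwise(weights, bits=4, group_size=64):
    """Split a flat weight list into groups and quantize each group separately."""
    return [
        quantize_group(weights[start:start + group_size], bits)
        for start in range(0, len(weights), group_size)
    ]


# Toy example: 8 weights, hypothetical group size of 4 -> 2 groups.
weights = [0.12, -0.5, 0.33, 0.9, -0.07, 0.41, -0.88, 0.05]
groups = quantize_groupwise(weights, bits=4, group_size=4)
restored = [x for codes, s, z in groups for x in dequantize_group(codes, s, z)]
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(f"max reconstruction error: {max_err:.4f}")
```

With round-to-nearest, each restored weight stays within half a quantization step (scale / 2) of the original. HQQ's parameter optimization aims to reduce this reconstruction error further.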

3. Smashing the Model

Now you can smash the model, which can take up to 2 minutes. Don’t forget to replace the token with the one provided by PrunaAI.

[ ]:
model = smash(
    model=model,
    token="<your_pruna_token>",
    smash_config=smash_config,
)
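Rather than pasting the token into the notebook, you may prefer to read it from an environment variable. The helper below is a minimal sketch. The variable name PRUNA_TOKEN is a hypothetical choice; pruna_pro does not look it up on its own.

```python
import os


def get_pruna_token(env_var: str = "PRUNA_TOKEN") -> str:
    """Read the Pruna token from an environment variable instead of hard-coding it.

    PRUNA_TOKEN is a hypothetical variable name used for this example only.
    """
    token = os.environ.get(env_var)
    if not token:
        raise RuntimeError(f"Set the {env_var} environment variable to your Pruna token.")
    return token
```

You would then call `smash(model=model, token=get_pruna_token(), smash_config=smash_config)`.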

4. Saving the model

vLLM cannot (yet) accept an already-loaded model object. You therefore have to save your pruna-optimized model to a directory on disk and then pass that path to vLLM.

[ ]:
smash_config.save_fns = []  # clear pruna's custom save functions: vLLM expects a standard Hugging Face checkpoint that from_pretrained can load
model.save_pretrained("path/to/pruna/model")
tokenizer.save_pretrained("path/to/pruna/model")

5. Loading your optimized model with vLLM

Assuming pruna_pro is installed in your environment (pip install pruna_pro), there is only one extra step before serving your LLM. You need to change the quantization method name (“quant_method”) of the saved model from “hqq” to “hqq_pruna_torchaoint4”, a name that vLLM recognizes and supports.

[ ]:
from pruna_pro.algorithms.quantization.utils.pruna_to_vllm import patch_quantization_name

patch_quantization_name("path/to/pruna/model", "hqq_pruna_torchaoint4")
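We have not inspected how patch_quantization_name works internally. Suppose it rewrites the quant_method field under quantization_config in the saved config.json. In that case, a hypothetical standard-library equivalent could look like the sketch below. The function set_quant_method and the assumed config layout are illustrative only. In practice, call the pruna_pro function shown above.

```python
import json
import tempfile
from pathlib import Path


def set_quant_method(model_dir: str, new_name: str) -> str:
    """Hypothetical stand-in for patch_quantization_name, for illustration only.

    Assumes config.json stores the method under quantization_config["quant_method"].
    Returns the previous method name.
    """
    config_path = Path(model_dir) / "config.json"
    config = json.loads(config_path.read_text())
    old_name = config["quantization_config"]["quant_method"]
    config["quantization_config"]["quant_method"] = new_name
    config_path.write_text(json.dumps(config, indent=2))
    return old_name


# Demo on a throwaway directory with a minimal, made-up config.json.
with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "config.json").write_text(
        json.dumps({"quantization_config": {"quant_method": "hqq"}})
    )
    previous = set_quant_method(tmp, "hqq_pruna_torchaoint4")
    patched = json.loads(Path(tmp, "config.json").read_text())
```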

And that’s it! You can now load your pruna-optimized LLM in vLLM as you usually would, passing the same quantization name.

[ ]:
from vllm import LLM, SamplingParams  # noqa: I001
llm = LLM(model="path/to/pruna/model", quantization="hqq_pruna_torchaoint4")

6. Run your (v)LLM

[ ]:
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, min_tokens=100, max_tokens=100)
prompt = "Hello, how are you?"
outputs = llm.generate(prompt, sampling_params)
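The temperature and top_p values in sampling_params control how each next token is drawn. The toy, pure-Python function below sketches the standard recipe on a single list of logits:

- scale the logits by 1/temperature and apply a softmax;
- keep the smallest set of most-likely tokens whose probability mass reaches top_p;
- sample from that set.

It is an illustration only. vLLM applies these steps on GPU over batched logits, and its exact implementation details may differ.

```python
import math
import random


def sample_top_p(logits, temperature=0.8, top_p=0.95, rng=random):
    """Sample one token index with temperature scaling and nucleus (top-p) filtering."""
    # 1. Temperature: divide the logits, then softmax (subtract the max for stability).
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # 2. Nucleus: keep the most likely tokens until their mass reaches top_p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # 3. Draw from the kept tokens; random.choices renormalizes the weights.
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


# Toy next-token logits over a 4-token vocabulary.
rng = random.Random(0)
toy_logits = [2.0, 1.0, 0.5, -3.0]
draws = [sample_top_p(toy_logits, rng=rng) for _ in range(1000)]
```

With these toy logits and temperature 0.8, the three highest-scoring tokens already cover more than 95% of the probability mass. The last token therefore falls outside the nucleus and is never drawn.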
[ ]:
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")  # noqa: T201

Wrap Up

Congratulations! You can now serve LLMs that benefit from both vLLM and pruna optimizations.

Need more control over your LLM serving engine? Check out the other pruna-compatible engines: Triton Inference Server + pruna (/setup/tritonserver.rst) and LitServe + pruna (/setup/litserve.rst).