Getting Started with Pruna & LitServe
In this guide, you will learn how to use LitServe to serve your pruna models. LitServe is a flexible serving engine for AI models built on FastAPI. Features like batching, streaming, and GPU autoscaling eliminate the need to rebuild a FastAPI server per model.
Step 1. Installation
To use LitServe, you'll need pruna and its dependencies installed. You can take a look at the installation guide or Dockerfile for more information.
LitServe itself can simply be installed with the following command:
pip install litserve
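If pruna itself is not installed yet, it is available from PyPI as well; a minimal sketch (see the installation guide for the variant that matches your hardware):
pip install pruna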
Now, let’s see how to use LitServe to serve your pruna models.
Step 2. Understand the Basics of LitServe
LitServe is a deployment framework that helps you serve and deploy any AI model, Lightning fast, which makes it a great fit for serving your pruna models. Before using it, it helps to understand the basics of LitServe; to get started, you can take a look at the LitServe documentation.
To create a LitAPI, you simply define a class that inherits from ls.LitAPI and implements the setup, decode_request, predict, and encode_response methods to process the request and response.
Below is a minimal example of how to serve an arbitrary model with LitServe. Create a file called server.py and add the following code:
import litserve as ls


class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        # Set up two toy "models" once per worker
        self.model1 = lambda x: x**2
        self.model2 = lambda x: x**3

    def decode_request(self, request):
        # Pull the input value out of the request payload
        return request["input"]

    def predict(self, x):
        squared = self.model1(x)
        cubed = self.model2(x)
        output = squared + cubed
        return {"output": output}

    def encode_response(self, output):
        # Unwrap the prediction and build the response payload
        return {"output": output["output"]}


if __name__ == "__main__":
    api = SimpleLitAPI()
    server = ls.LitServer(api, accelerator="gpu")
    server.run(port=8000)
You can then start the server locally with the following command:
lightning deploy server.py
Or deploy it to the cloud by adding the --cloud flag:
lightning deploy server.py --cloud
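Once the server is running, you can send a quick test request to the default /predict endpoint; a minimal sketch, assuming the server is reachable on localhost:8000:

import requests

# The simple API squares and cubes the input, so 4 yields 16 + 64 = 80
response = requests.post("http://localhost:8000/predict", json={"input": 4})
print(response.json())  # expected: {"output": 80}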
This is a simple example of how to serve an arbitrary model with LitServe, but there are many more configuration options, like batching; a sketch of enabling it is shown below. You can find more examples in the LitServe documentation.
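For instance, dynamic batching can be enabled when constructing the server. This is a minimal sketch rather than a complete recipe: max_batch_size and batch_timeout are LitServer arguments, but with batching enabled predict receives a batch of decoded inputs, so the LitAPI usually also needs the optional batch/unbatch hooks described in the LitServe documentation.

import litserve as ls

if __name__ == "__main__":
    api = SimpleLitAPI()
    # Group up to 4 concurrent requests into a single predict call,
    # waiting at most 50 ms to fill the batch
    server = ls.LitServer(api, accelerator="gpu", max_batch_size=4, batch_timeout=0.05)
    server.run(port=8000)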
Step 3. Self-Host Compressed AI Models with LitServe
In this step, we will see how to deploy some of the most popular models with LitServe and pruna.
This is an example of how to serve an optimized Stable Diffusion model with LitServe and pruna, although the same workflow can be applied to any text-to-image model, like FLUX. You can find the full code in the LitServe documentation.
First, we define the LitAPI. Then, we define the SmashConfig to optimize the model for inference.
In this case, we will be optimizing Segmind-Vega, a distilled version of Stable Diffusion XL, using the deepcache cacher, the torch_compile compiler, and the hqq_diffusers quantizer.
import base64
from io import BytesIO

import torch
from diffusers import DiffusionPipeline
import litserve as ls

from pruna import SmashConfig, smash


class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        # Load the base diffusion pipeline
        self.device = device
        self.model = DiffusionPipeline.from_pretrained(
            "segmind/Segmind-Vega",
            torch_dtype=torch.float16,
        ).to(device)

        # Configure the SmashConfig to optimize the model for inference
        smash_config = SmashConfig(device=device)
        smash_config["cacher"] = "deepcache"
        smash_config["deepcache_interval"] = 2
        smash_config["compiler"] = "torch_compile"
        smash_config["quantizer"] = "hqq_diffusers"
        smash_config["hqq_diffusers_weight_bits"] = 4
        smash_config["hqq_diffusers_group_size"] = 64
        smash_config["hqq_diffusers_backend"] = "marlin"

        self.smashed_model = smash(
            model=self.model,
            smash_config=smash_config,
        )

    def decode_request(self, request):
        # Extract the prompt from the request
        prompt = request["prompt"]
        return prompt

    def predict(self, prompt):
        # Generate an image from the prompt
        with torch.no_grad():
            # The pipeline output exposes the generated PIL images under the "images" key
            images = self.smashed_model(
                prompt,
                num_inference_steps=28,
                guidance_scale=7.5,
            )["images"]
            image = images[0]  # Retrieve the first generated image
        return image

    def encode_response(self, image):
        # Convert the generated PIL Image to a Base64 string
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_str}


# Start the server
if __name__ == "__main__":
    api = SimpleLitAPI()
    server = ls.LitServer(api)
    server.run(port=8000)
After starting the server, you can test it with the following code:
import base64
from io import BytesIO

import requests
from PIL import Image

data = {"prompt": "A beautiful sunset over a calm ocean"}

# Send a request to your LitServe server
response = requests.post("http://localhost:8000/predict", json=data)

# Get the Base64-encoded image string from the response
img_str = response.json().get("image")

if img_str:
    # Decode the Base64 string to bytes
    img_bytes = base64.b64decode(img_str)
    # Convert the bytes to a PIL Image
    img = Image.open(BytesIO(img_bytes))
    # Save the image
    img.save("generated_image.png")
This is an example of how to serve an optimized HuggingFaceTB/SmolLM2-1.7B-Instruct model with LitServe and pruna. You can find the full code in the LitServe documentation.
First, we define the LitAPI. Then, we define the SmashConfig to optimize the model for inference.
In this case, we will be optimizing HuggingFaceTB/SmolLM2-1.7B-Instruct using the torch_compile compiler and the hqq quantizer.
import litserve as ls
from transformers import pipeline

from pruna import SmashConfig, smash


class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        # Load the model as a text-generation pipeline
        model_name = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
        pipe = pipeline(
            task="text-generation",
            model=model_name,
        )

        # Configure the SmashConfig to optimize the model for inference
        smash_config = SmashConfig(device=device)
        smash_config["quantizer"] = "hqq"
        smash_config["hqq_weight_bits"] = 4
        smash_config["hqq_compute_dtype"] = "torch.bfloat16"
        smash_config["compiler"] = "torch_compile"
        smash_config["torch_compile_fullgraph"] = True
        smash_config["torch_compile_dynamic"] = True
        smash_config["torch_compile_mode"] = "max-autotune"

        # Smash the model and put it back into the pipeline
        # so that generation uses the optimized model
        pipe.model = smash(
            model=pipe.model,
            smash_config=smash_config,
        )
        self.pipe = pipe

    def decode_request(self, request):
        return request["prompt"]

    def predict(self, prompt):
        return self.pipe(prompt, max_new_tokens=200)

    def encode_response(self, output):
        return {"output": output}


if __name__ == "__main__":
    api = SimpleLitAPI()
    server = ls.LitServer(api)
    server.run(port=8000)
After starting the server, you can test it with the following code:
import requests
data = {"prompt": "What is the capital of France?"}
response = requests.post("http://localhost:8000/predict", json=data)
print(response.json())
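The shape of the reply depends on how the pipeline output is serialized; assuming the standard transformers text-generation format (a list of dicts with a generated_text field), it can be unpacked roughly like this:

# Each entry corresponds to one generated sequence
output = response.json()["output"]
print(output[0]["generated_text"])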
This is an example of how to serve an optimized Whisper model with LitServe and pruna. You can find the full code in the LitServe documentation.
First, we define the LitAPI. Then, we define the SmashConfig to optimize the model for inference.
In this case, we will be optimizing Whisper using the c_whisper compiler and the whisper_s2t batcher.
import torch
import litserve as ls
from transformers import AutoModelForSpeechSeq2Seq

from pruna import SmashConfig, smash


class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        # Load the OpenAI Whisper model. You can specify other models like "base", "small", etc.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        model_id = "openai/whisper-large-v3"
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            use_safetensors=True,
            low_cpu_mem_usage=True,
        )
        model.to(device)

        # Configure the SmashConfig to optimize the model for inference
        smash_config = SmashConfig(device=device)
        smash_config.add_tokenizer(model_id)
        smash_config.add_processor(model_id)
        smash_config["compiler"] = "c_whisper"
        smash_config["batcher"] = "whisper_s2t"

        self.smashed_model = smash(
            model=model,
            smash_config=smash_config,
        )

    def decode_request(self, request):
        # Assuming the request sends the path to the audio file.
        # In a more robust implementation, you would handle audio data directly.
        return request["audio_path"]

    def predict(self, audio_path):
        # Process the audio file and return the transcription result
        return self.smashed_model(audio_path)

    def encode_response(self, output):
        # Return the transcription text
        return {"transcription": output["text"]}


if __name__ == "__main__":
    api = SimpleLitAPI()
    server = ls.LitServer(api, accelerator="gpu", timeout=1000, workers_per_device=2)
    server.run(port=8000)
After starting the server, you can test it with the following code:
from pathlib import Path

import requests

# Download a sample audio file if it is not already present
audio_sample = Path("sam_altman_lex_podcast_367.flac")
if not audio_sample.exists():
    response = requests.get(
        "https://huggingface.co/datasets/reach-vb/random-audios/resolve/main/sam_altman_lex_podcast_367.flac",
        stream=True,
    )
    response.raise_for_status()
    audio_sample.write_bytes(response.content)

# Send the path of the downloaded file to the server
url = "http://127.0.0.1:8000/predict"
response = requests.post(url, json={"audio_path": str(audio_sample)})
print(response.json())
Ending Notes
In this guide, we have seen how to use LitServe to serve your pruna models. LitServe is a powerful tool that can help you deploy your models quickly and easily. You can even deploy your models to Lightning AI Hub. If you have any questions or feedback, please don’t hesitate to contact us.