
Getting Started with Pruna & LitServe

In this guide, you will learn how to use LitServe to serve your pruna models. LitServe is a flexible serving engine for AI models built on FastAPI. Features like batching, streaming, and GPU autoscaling eliminate the need to rebuild a FastAPI server per model.

Step 1. Installation

To use LitServe, you first need pruna and its dependencies installed. You can take a look at the installation guide or Dockerfile for more information.

LitServe itself can simply be installed with the following command:

pip install litserve
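
If pruna is not installed yet, it is also available from PyPI. The command below assumes the default build is suitable for your hardware; otherwise, refer to the installation guide for the right variant:

pip install pruna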

Now, let’s see how to use LitServe to serve your pruna models.

Step 2. Understand the Basics of LitServe

LitServe is a deployment framework that helps you serve and deploy any AI model, Lightning fast. This makes it a great fit for serving your pruna models, but to use it well you first need to understand the basics of LitServe. To get started, you can take a look at the LitServe documentation.

To create a LitAPI, you define a class that inherits from ls.LitAPI and implements the setup, decode_request, predict, and encode_response methods to process the request and response. Below is a minimal example of how to serve an arbitrary model with LitServe. Create a file called server.py and add the following code:

import litserve as ls

class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        self.model1 = lambda x: x**2
        self.model2 = lambda x: x**3

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        squared = self.model1(x)
        cubed = self.model2(x)
        output = squared + cubed
        return {"output": output}

    def encode_response(self, output):
        # predict returns a dict, so unwrap it here to avoid a nested response
        return {"output": output["output"]}

if __name__ == "__main__":
    api = SimpleLitAPI()
    server = ls.LitServer(api, accelerator="gpu")
    server.run(port=8000)

You can then simply start the server with the following command:

python server.py
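
Once the server is running, you can send a quick test request to its /predict endpoint. This assumes the server is reachable on localhost, port 8000, as configured above:

import requests

# Query the minimal server defined above
response = requests.post("http://localhost:8000/predict", json={"input": 4})
print(response.json())  # {"output": 80}, since 4**2 + 4**3 = 80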

This is a simple example of how to serve an arbitrary model with LitServe, but there are many more configuration options, such as batching (see the sketch below). You can find more examples in the LitServe documentation.
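
As an illustrative sketch of batching (the class name BatchedLitAPI is made up for this example, and it assumes your LitServe version accepts the max_batch_size and batch_timeout arguments on LitServer; newer releases may configure batching differently), the API has to handle a list of decoded inputs in predict:

import litserve as ls

class BatchedLitAPI(ls.LitAPI):
    def setup(self, device):
        self.model = lambda x: x**2

    def decode_request(self, request):
        return request["input"]

    def predict(self, inputs):
        # With batching enabled, LitServe passes a list of decoded inputs
        return [{"output": self.model(x)} for x in inputs]

    def encode_response(self, output):
        # Called once per item after LitServe unbatches the predictions
        return {"output": output["output"]}

if __name__ == "__main__":
    api = BatchedLitAPI()
    # Group up to 4 concurrent requests, waiting at most 50 ms to fill a batch
    server = ls.LitServer(api, max_batch_size=4, batch_timeout=0.05)
    server.run(port=8000)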

Step 3. Self-Host Compressed AI Models with LitServe

In this step, we will see how to deploy some of the most popular models with LitServe and pruna.

This is an example of how to serve an optimized Stable Diffusion model with LitServe and pruna, although the same workflow can be applied to any text-to-image model, like FLUX. You can find the full code in the LitServe documentation.

First, we define the LitAPI; inside its setup method, we define the SmashConfig that optimizes the model for inference. In this case, we optimize Segmind-Vega, a distilled version of Stable Diffusion XL, using the deepcache cacher, the torch_compile compiler, and the hqq_diffusers quantizer.

import torch, base64
from diffusers import DiffusionPipeline
from io import BytesIO
import litserve as ls
from pruna import SmashConfig, smash

class SimpleLitAPI(ls.LitAPI):
    def setup(self, device):
        # Load the diffusion pipeline and move it to the target device
        self.device = device
        self.model = DiffusionPipeline.from_pretrained(
            "segmind/Segmind-Vega",
            torch_dtype=torch.float16,
        ).to(device)

        # Configure the SmashConfig to optimize the model for inference
        smash_config = SmashConfig(device=device)
        smash_config["cacher"] = "deepcache"
        smash_config["deepcache_interval"] = 2
        smash_config["compiler"] = "torch_compile"
        smash_config["quantizer"] = "hqq_diffusers"
        smash_config["hqq_diffusers_weight_bits"] = 4
        smash_config["hqq_diffusers_group_size"] = 64
        smash_config["hqq_diffusers_backend"] = "marlin"
        self.smashed_model = smash(
            model=self.model,
            smash_config=smash_config,
        )

    def decode_request(self, request):
        # Extract prompt from request
        prompt = request["prompt"]
        return prompt

    def predict(self, prompt):
        # Generate image from prompt
        with torch.no_grad():
            # The pipeline output exposes the generated PIL images under the "images" key
            images = self.smashed_model(
                prompt,
                num_inference_steps=28,
                guidance_scale=7.5,
            )["images"]
            image = images[0]  # Take the first generated image
        return image

    def encode_response(self, image):
        # Convert the generated PIL Image to a Base64 string
        buffered = BytesIO()
        image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")

        return {"image": img_str}

# Starting the server
if __name__ == "__main__":
    # LitServer selects an appropriate accelerator (e.g., 'cuda' or 'cpu') automatically
    api = SimpleLitAPI()
    server = ls.LitServer(api)
    server.run(port=8000)

After starting the server (again with python server.py), you can test it with the following code:

import requests
import base64
from io import BytesIO
from PIL import Image

data = {"prompt": "A beautiful sunset over a calm ocean"}

# Send a request to your LitServe server
response = requests.post("http://localhost:8000/predict", json=data)

# Get the Base64-encoded image string from the response
img_str = response.json().get("image")

if img_str:
    # Decode the Base64 string to bytes
    img_bytes = base64.b64decode(img_str)

    # Convert bytes data to PIL Image
    img = Image.open(BytesIO(img_bytes))

    # Save the image
    img.save("generated_image.png")

Ending Notes

In this guide, we have seen how to use LitServe to serve your pruna models. LitServe is a powerful tool that can help you deploy your models quickly and easily. You can even deploy your models to Lightning AI Hub. If you have any questions or feedback, please don’t hesitate to contact us.