Pruna x Triton Inference Server

This tutorial describes how to integrate pruna with NVIDIA’s Triton Inference Server. It is based on our example repository. By using Triton to serve models optimized by Pruna, you can achieve lower latency, higher throughput, and reduced memory usage for production-scale AI deployments.

Pruna + Triton Workflow

  1. Optimize the model with Pruna: Apply Pruna’s techniques (quantization, pruning, compilation, caching, etc.) to reduce the model’s footprint and improve inference speed.

  2. Configure the model for Triton: Set up a model repository that includes your optimized model along with a config.pbtxt and any Python scripts required to handle inference logic.

  3. Deploy and test: Build a Docker image that bundles your optimized model, pruna, and Triton. Run the Triton container and test inference through the Python tritonclient.

Example: Serving Stable Diffusion with Pruna + Triton

Below is an example that uses Stable Diffusion, but these steps generalize to other generative or standard neural network models.

Step 1: Prerequisites

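  • Docker: Required to run the Triton Inference Server in a containerized environment. To use the --gpus flag shown in Step 6, the host also needs NVIDIA GPU drivers and the NVIDIA Container Toolkit.

  • Python 3.9+: Needed to install and run pruna and the Triton client libraries.

  • Triton client library: Install via:

    pip install "tritonclient[grpc]"
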

Step 2: Build the Docker Image

Create a Dockerfile that starts from NVIDIA’s Triton base image and installs pruna, PyTorch (a pruna dependency), and any other libraries your model needs (for Stable Diffusion, diffusers and transformers). An example snippet:

FROM nvcr.io/nvidia/tritonserver:23.12-py3

ENV DEBIAN_FRONTEND=noninteractive
ENV TZ=Etc/UTC

# Install system dependencies
RUN apt-get update && \
    apt-get install -y wget curl git vim sudo cmake build-essential \
    libssl-dev libffi-dev python3-dev python3-venv python3-pip libsndfile1 && \
    rm -rf /var/lib/apt/lists/*

# Upgrade pip before installing any Python packages
RUN pip3 install --upgrade pip

# Install general-purpose Python packages
RUN pip3 install packaging psutil pexpect ipywidgets jupyterlab ipykernel \
    librosa soundfile

# Install Pruna (use pruna_pro==0.2.0 instead, depending on the algorithms used)
RUN pip3 install pruna==0.2.0

# Libraries imported by the Stable Diffusion model.py in Step 5
RUN pip3 install diffusers transformers

Build the Docker image:

docker build -t tritonserver_pruna .
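Before wiring up Triton, it can help to confirm that the image actually contains the packages model.py will import. The script below is a minimal sketch. The file name check_env.py and the module list are assumptions, so adjust them for your own model. One way to run it is to mount the current directory into the container: docker run --rm --gpus=all -v "$PWD:/work" tritonserver_pruna python3 /work/check_env.py

```python
"""check_env.py: report which Python packages the Triton image is missing (a sketch)."""
import importlib.util
import sys

# Assumed list of modules the Stable Diffusion example imports; adjust for your model.
REQUIRED_MODULES = ["pruna", "torch", "diffusers", "transformers", "numpy"]


def missing_modules(names):
    """Return the subset of `names` that cannot be found by this interpreter."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
        sys.exit(1)

    import torch  # imported only after the check, so a missing torch is reported above

    print("All required modules found. CUDA available:", torch.cuda.is_available())
```

A non-zero exit code makes the check easy to use in a CI step that builds the image.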

Step 3: Triton Model Repository Structure

Triton requires a specific structure for serving models (the “model repository”):

model_repository/
└── stable_diffusion/
    ├── config.pbtxt
    └── 1/
        └── model.py
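
  • stable_diffusion: The directory for the model you want to serve. Its name is the model name that clients request, and it must match the name field in config.pbtxt if that field is set.

  • config.pbtxt: Defines the input/output format, backend, batching, GPU configuration, etc.

  • 1/: A numeric version directory. Triton loads model versions from these directories (by default, only the latest).

  • model.py: Contains the Python code for loading the Stable Diffusion model, optimizing it with pruna, and running inference.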

Adjust naming and structure as needed for your models.
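If you prefer to create this layout programmatically, the following sketch builds the directories with pathlib and reports which files are still missing. The function names scaffold_repository and missing_files are hypothetical helpers, not part of Triton or pruna.

```python
"""Create the Triton model repository layout described above (a sketch)."""
from pathlib import Path


def scaffold_repository(root, model_name="stable_diffusion", version=1):
    """Create <root>/<model_name>/<version>/ and return where config.pbtxt and model.py go."""
    model_dir = Path(root) / model_name
    version_dir = model_dir / str(version)
    version_dir.mkdir(parents=True, exist_ok=True)
    return model_dir / "config.pbtxt", version_dir / "model.py"


def missing_files(root, model_name="stable_diffusion", version=1):
    """List files of the expected layout (relative to root) that do not exist yet."""
    model_dir = Path(root) / model_name
    expected = [model_dir / "config.pbtxt", model_dir / str(version) / "model.py"]
    return [str(p.relative_to(root)) for p in expected if not p.exists()]


if __name__ == "__main__":
    config_path, model_path = scaffold_repository("model_repository")
    print("Write config.pbtxt to:", config_path)
    print("Write model.py to:", model_path)
    print("Still missing:", missing_files("model_repository"))
```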

Step 4: Configure config.pbtxt

In the case of Stable Diffusion, your config.pbtxt might look like this:

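  • Inputs: A string input (TYPE_STRING) holding the text prompt.

  • Outputs: A float32 image tensor of shape (3, 512, 512), channels first.

  • Max batch size: e.g., 4. Triton adds the batch dimension in front of dims, so a request sends inputs of shape [batch, 1] and receives outputs of shape [batch, 3, 512, 512].

  • Backend: python

  • Instance group: KIND_GPU, so the model instance runs on a GPU.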

Example:

name: "stable_diffusion"
backend: "python"
max_batch_size: 4
input [
  {
    name: "INPUT_TEXT"
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]
output [
  {
    name: "OUTPUT"
    data_type: TYPE_FP32
    dims: [ 3, 512, 512 ]
  }
]
instance_group [{ kind: KIND_GPU }]
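To see how this config translates into tensor shapes: because max_batch_size is greater than 0, Triton treats the first dimension of every input and output as the batch dimension, which is not listed in dims. The hypothetical helper below spells out that rule for the values in this config and rejects batches larger than max_batch_size, as Triton does.

```python
"""Tensor shapes implied by config.pbtxt (a sketch of Triton's batching rule)."""

MAX_BATCH_SIZE = 4
INPUT_DIMS = [1]             # INPUT_TEXT dims from config.pbtxt
OUTPUT_DIMS = [3, 512, 512]  # OUTPUT dims from config.pbtxt


def full_shape(batch_size, dims, max_batch_size=MAX_BATCH_SIZE):
    """Return the full tensor shape for one request.

    With max_batch_size > 0, the batch dimension is prepended to dims.
    With max_batch_size == 0, the model does not batch and the shape is just dims.
    """
    if max_batch_size == 0:
        return list(dims)
    if not 1 <= batch_size <= max_batch_size:
        raise ValueError(f"batch_size must be between 1 and {max_batch_size}, got {batch_size}")
    return [batch_size] + list(dims)


if __name__ == "__main__":
    for batch in (1, 4):
        print(batch, full_shape(batch, INPUT_DIMS), full_shape(batch, OUTPUT_DIMS))
```

This is why the client in Step 7 reshapes its prompt array to (-1, 1).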

Step 5: Implement Inference in model.py

Within your model.py, you’ll:

  1. Load the Stable Diffusion pipeline (or your model of choice).

  2. Apply Pruna optimizations (e.g., step caching with DeepCache, or other techniques) and keep a reference to the smashed model. If you use pruna_pro, don’t forget to pass your token to the smash() function.

  3. Define Triton’s execute() function, which reads the input tensors, runs inference with the smashed model, and returns outputs in the format declared in config.pbtxt.

Example:

import json

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from diffusers import DDIMScheduler, StableDiffusionPipeline
from pruna import SmashConfig, smash


class TritonPythonModel:
    def initialize(self, args):
        """Called once when the model is being loaded."""
        # Parse the model configuration and read the declared output dtype (FP32)
        self.model_config = json.loads(args["model_config"])
        output_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT")
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

        # Load the Stable Diffusion pipeline in half precision on the GPU.
        # runwayml/stable-diffusion-v1-5 has been removed from the Hugging Face Hub;
        # stable-diffusion-v1-5/stable-diffusion-v1-5 is the maintained mirror.
        pipe = StableDiffusionPipeline.from_pretrained(
            "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
        )
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
        pipe = pipe.to("cuda")

        # Configure DeepCache step caching with a cache interval of 3
        smash_config = SmashConfig()
        smash_config["cacher"] = "deepcache"
        smash_config["deepcache_interval"] = 3

        # Smash the model (if you use pruna_pro, also pass your token to smash())
        self.smashed_model = smash(model=pipe, smash_config=smash_config)

    def execute(self, requests):
        """Called for inference requests."""
        responses = []
        for request in requests:
            # TYPE_STRING inputs arrive as a numpy object array of bytes, shape [batch, 1]
            input_text_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT")
            prompts = [
                t.decode("utf-8") if isinstance(t, bytes) else str(t)
                for t in input_text_tensor.as_numpy().flatten()
            ]

            # Run the smashed pipeline on all prompts of this request at once
            # (Stable Diffusion v1.5 generates 512x512 images by default)
            images = self.smashed_model(prompts).images

            # Convert each PIL image (H, W, C, uint8) to channels-first floats in [0, 1]
            # so the batch matches dims [3, 512, 512] and TYPE_FP32 in config.pbtxt
            output_array = np.stack(
                [np.asarray(img, dtype=np.float32).transpose(2, 0, 1) / 255.0 for img in images]
            ).astype(self.output_dtype)

            # Create the Triton output tensor and response
            output_tensor = pb_utils.Tensor("OUTPUT", output_array)
            responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))

        return responses

    def finalize(self):
        """Called when the model is being unloaded."""
        print("Cleaning up...")
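The tensor conversions inside execute() are easy to get wrong, and they can be checked without a GPU or a running server. A TYPE_STRING input reaches the Python backend as a NumPy object array of bytes, so prompts need to be decoded explicitly. Calling .astype(str) on such an array can produce strings like "b'...'" instead of the prompt text. On the output side, the declared TYPE_FP32 tensor with dims [3, 512, 512] means each pipeline image must be converted from height × width × channels to channels first. The helper names below are hypothetical, and scaling pixel values to [0, 1] is an assumed convention.

```python
"""CPU-only sketch of the input/output conversions an execute() needs for this config."""
import numpy as np


def decode_prompts(raw):
    """Turn a TYPE_STRING input (object array of bytes, shape [batch, 1]) into str prompts."""
    return [t.decode("utf-8") if isinstance(t, bytes) else str(t) for t in raw.flatten()]


def images_to_output(images, dtype=np.float32):
    """Stack HWC uint8 images into a [batch, 3, H, W] array scaled to [0, 1]."""
    chw = [np.asarray(img, dtype=np.float32).transpose(2, 0, 1) / 255.0 for img in images]
    return np.stack(chw).astype(dtype)


if __name__ == "__main__":
    raw = np.array([b"a serene mountain view", b"a red bicycle"], dtype=object).reshape(-1, 1)
    print(decode_prompts(raw))
    # Stand-ins for two 512x512 RGB images returned by the pipeline
    fake_images = [
        np.zeros((512, 512, 3), dtype=np.uint8),
        np.full((512, 512, 3), 255, dtype=np.uint8),
    ]
    out = images_to_output(fake_images)
    print(out.shape, out.dtype, out.min(), out.max())
```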

Step 6: Run the Triton Server

With your model repository prepared, start Triton using your built Docker image:

docker run --rm --gpus=all \
    -p 8000:8000 -p 8001:8001 -p 8002:8002 \
    -v "/absolute/path/to/model_repository:/models" \
    tritonserver_pruna tritonserver --model-repository=/models

Parameter                                       Meaning
--rm                                            Removes the container when it stops.
--gpus=all                                      Exposes all available GPUs to the container.
-p 8000:8000 -p 8001:8001 -p 8002:8002          Publishes Triton’s HTTP (8000), gRPC (8001), and metrics (8002) ports.
-v /absolute/path/to/model_repository:/models   Mounts your model repository into the container at /models. Replace the host path with your own.
tritonserver_pruna                              The Docker image built with Pruna and Triton.
--model-repository=/models                      Instructs Triton to load models from /models.
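Loading the pipeline and applying pruna in initialize() can take a while, and the weights may also need to be downloaded on each fresh container start. A client should therefore wait until the model reports ready. The sketch below polls the readiness checks that tritonclient.grpc.InferenceServerClient exposes through is_server_ready() and is_model_ready(). The wait_until_ready helper, its 600-second default timeout, and its 5-second polling interval are assumptions.

```python
"""Wait until Triton and the model report ready before sending requests (a sketch)."""
import time


def wait_until_ready(client, model_name, timeout_s=600.0, poll_s=5.0):
    """Poll `client` until the server and `model_name` are ready, or raise TimeoutError.

    `client` is any object with is_server_ready() and is_model_ready(name),
    such as tritonclient.grpc.InferenceServerClient.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            if client.is_server_ready() and client.is_model_ready(model_name):
                return
        except Exception:
            # The server may refuse connections while it is still starting up
            pass
        if time.monotonic() >= deadline:
            raise TimeoutError(f"{model_name} was not ready after {timeout_s} seconds")
        time.sleep(poll_s)


if __name__ == "__main__":
    from tritonclient.grpc import InferenceServerClient

    wait_until_ready(InferenceServerClient(url="localhost:8001"), "stable_diffusion")
    print("stable_diffusion is ready")
```

Accepting any client object, rather than constructing one internally, keeps the helper usable with the HTTP client or a test double.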

Step 7: Test Inference with a Python Client

Use the tritonclient Python library to send requests. Below is a simple example:

import numpy as np
from tritonclient.grpc import InferenceServerClient, InferInput

# Connect to Triton Server over gRPC
client = InferenceServerClient(url="localhost:8001")

# Prepare the input: shape [batch, 1] = batch dimension + dims [1] from config.pbtxt
input_text = np.array(["a serene mountain view"], dtype=object).reshape(-1, 1)
input_tensor = InferInput("INPUT_TEXT", input_text.shape, "BYTES")
input_tensor.set_data_from_numpy(input_text)

# Perform inference
response = client.infer(model_name="stable_diffusion", inputs=[input_tensor])
output_data = response.as_numpy("OUTPUT")

print("Generated image shape:", output_data.shape)
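To inspect the result, you can turn each image in OUTPUT back into a PNG file. This sketch assumes that the model returns channels-first float32 images with values in [0, 1], matching the declared dims [3, 512, 512]. It also assumes that Pillow is installed on the client (pip install pillow). The random array in the __main__ block stands in for output_data from the client above; replace it with the real response. The helper name to_hwc_uint8 is hypothetical.

```python
"""Save the OUTPUT tensor returned by Triton as PNG files (a sketch)."""
import numpy as np


def to_hwc_uint8(chw):
    """Convert one [3, H, W] float image in [0, 1] to an [H, W, 3] uint8 array."""
    hwc = np.clip(chw, 0.0, 1.0).transpose(1, 2, 0)
    return (hwc * 255.0).round().astype(np.uint8)


if __name__ == "__main__":
    from PIL import Image  # requires Pillow on the client

    # Stand-in for response.as_numpy("OUTPUT"), shape [batch, 3, 512, 512]
    output_data = np.random.rand(1, 3, 512, 512).astype(np.float32)
    for i, chw in enumerate(output_data):
        Image.fromarray(to_hwc_uint8(chw)).save(f"generated_{i}.png")
```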

Additional Resources