Pruna x Triton Inference Server
This tutorial describes how to integrate pruna with NVIDIA’s Triton Inference Server. It is based on our example repository. By using Triton to serve models optimized by Pruna, you can achieve lower latency, higher throughput, and reduced memory usage for production-scale AI deployments.
Pruna + Triton Workflow
1. Optimize the model with Pruna: apply Pruna's techniques (quantization, pruning, compilation, caching, etc.) to reduce the model's footprint and improve inference speed (a minimal sketch follows this list).
2. Configure the model for Triton: set up a model repository that includes your optimized model along with a config.pbtxt and any Python scripts required to handle inference logic.
3. Deploy and test: build a Docker image that bundles your optimized model, pruna, and Triton. Run the Triton container and test inference through the Python tritonclient.
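For reference, the optimization step on its own looks roughly like this. It is a minimal sketch that reuses the same model and deepcache settings as the full example later in this tutorial; adapt the model and configuration to your use case.

import torch
from diffusers import StableDiffusionPipeline
from pruna import SmashConfig, smash

# Load a base pipeline (illustrative model choice)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Configure and apply Pruna's optimizations
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"
smash_config["deepcache_interval"] = 3
smashed_model = smash(model=pipe, smash_config=smash_config)

# The smashed model is used like the original pipeline
image = smashed_model("a serene mountain view").images[0]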
Example: Serving Stable Diffusion with Pruna + Triton
Below is an example that uses Stable Diffusion, but these steps generalize to other generative or standard neural network models.
Step 1: Prerequisites
Docker: required to run the Triton Inference Server in a containerized environment.
Python 3.9+: needed to install and run pruna and the Triton client libraries.
Triton Client Library: install via
pip install tritonclient[grpc]
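To confirm the installation, a quick import check is enough. This is purely a sanity check; run it in the environment where you installed the client.

import numpy as np
import tritonclient.grpc as grpcclient

# If both imports succeed, the gRPC client and numpy are ready to use
print("tritonclient gRPC module:", grpcclient.__name__)
print("numpy version:", np.__version__)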
Step 2: Build the Docker Image
Create a Dockerfile that starts from NVIDIA's Triton base image and includes pruna (pruna[gpu]), PyTorch, and any other libraries you need. An example snippet:
FROM nvcr.io/nvidia/tritonserver:23.12-py3
ENV DEBIAN_FRONTEND=noninteractive
ENV TZ=Etc/UTC
# Install system dependencies
RUN apt-get update && \
apt-get install -y wget curl git vim sudo cmake build-essential \
libssl-dev libffi-dev python3-dev python3-venv python3-pip libsndfile1 && \
rm -rf /var/lib/apt/lists/*
# Install Python packages
RUN pip3 install packaging psutil pexpect ipywidgets jupyterlab ipykernel \
librosa soundfile
# Upgrade pip
RUN pip3 install --upgrade pip
# Install Pruna
RUN pip3 install pruna==0.2.0 # or pip3 install pruna_pro==0.2.0 depending on the algorithms used
Build the Docker image:
docker build -t tritonserver_pruna .
Step 3: Triton Model Repository Structure
Triton requires a specific structure for serving models (the “model repository”):
model_repository/
└── stable_diffusion/
├── config.pbtxt
└── 1/
└── model.py
stable_diffusion: The directory for the model you want to serve.
config.pbtxt: Defines the input/output format, backend, batching, GPU configuration, etc.
model.py: Contains the Python code for loading and running the Stable Diffusion model (using pruna for optimization).
Adjust naming and structure as needed for your models.
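If you prefer to script the layout, a small helper like the following can create it. This is only a sketch: the paths mirror the tree above, and config.pbtxt and model.py still have to be filled in as described in the next steps.

from pathlib import Path

# Create the Triton model repository layout shown above
repo = Path("model_repository") / "stable_diffusion"
(repo / "1").mkdir(parents=True, exist_ok=True)

# Placeholder files to be filled in with the config and inference code below
(repo / "config.pbtxt").touch()
(repo / "1" / "model.py").touch()

print(sorted(p.as_posix() for p in Path("model_repository").rglob("*")))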
Step 4: Configure config.pbtxt
In the case of Stable Diffusion, your config.pbtxt might look like this:
Inputs: a string input for the text prompt.
Outputs: an image tensor of shape (3, 512, 512).
Max batch size: e.g., 4.
Backend: python.
Example:
name: "stable_diffusion"
backend: "python"
max_batch_size: 4
input [
{
name: "INPUT_TEXT"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
output [
{
name: "OUTPUT"
data_type: TYPE_FP32
dims: [ 3, 512, 512 ]
}
]
instance_group [{ kind: KIND_GPU }]
Step 5: Implement Inference in model.py
Within your model.py, you'll:
Load the Stable Diffusion pipeline (or your model of choice).
Apply Pruna optimizations (e.g., step caching or other advanced techniques). Don't forget to pass your token to the smash() function.
Define Triton's execute() function that takes input, runs inference, and returns the result in the correct format.
Example:
import json

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from diffusers import DDIMScheduler, StableDiffusionPipeline
from pruna import SmashConfig, smash


class TritonPythonModel:
    def initialize(self, args):
        """Called once when the model is being loaded."""
        # Load the Stable Diffusion pipeline
        self.pipe = StableDiffusionPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
        )
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe = self.pipe.to("cuda")

        # Initialize the SmashConfig
        smash_config = SmashConfig()
        smash_config["cacher"] = "deepcache"
        smash_config["deepcache_interval"] = 3

        # Smash the model
        self.smashed_model = smash(
            model=self.pipe,
            smash_config=smash_config,
        )

        # Parse model configuration
        self.model_config = json.loads(args["model_config"])

        # Get output data type
        output_config = pb_utils.get_output_config_by_name(self.model_config, "OUTPUT")
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

    def execute(self, requests):
        """Called for inference requests."""
        responses = []
        for request in requests:
            # Get input text and decode BYTES elements to Python strings
            input_text_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT_TEXT")
            input_texts = [
                t.decode("utf-8") if isinstance(t, bytes) else str(t)
                for t in input_text_tensor.as_numpy().flatten()
            ]

            # Generate one image per prompt using the smashed pipeline
            generated_images = []
            for text in input_texts:
                image = self.smashed_model(text).images[0]
                # Convert the PIL image to a (3, 512, 512) array to match config.pbtxt
                generated_images.append(np.asarray(image).transpose(2, 0, 1))

            # Stack into a (batch, 3, 512, 512) array with the configured output dtype
            output_array = np.stack(generated_images).astype(self.output_dtype)

            # Create Triton output tensor
            output_tensor = pb_utils.Tensor("OUTPUT", output_array)
            responses.append(pb_utils.InferenceResponse(output_tensors=[output_tensor]))
        return responses

    def finalize(self):
        """Called when the model is being unloaded."""
        print("Cleaning up...")
Step 6: Run the Triton Server
With your model repository prepared, start Triton using your built Docker image:
docker run --rm --gpus=all \
-p 8000:8000 -p 8001:8001 -p 8002:8002 \
-v "/absolute/path/to/model_repository:/models" \
tritonserver_pruna tritonserver --model-repository=/models
Parameter | Meaning
---|---
--rm | Remove the container when it stops.
--gpus=all | Exposes all available GPUs to the container.
-p 8000:8000 -p 8001:8001 -p 8002:8002 | Exposes ports for the HTTP, gRPC, and metrics APIs.
-v "/absolute/path/to/model_repository:/models" | Mounts your model repository into the container at /models.
tritonserver_pruna | The Docker image built with Pruna and Triton.
--model-repository=/models | Instructs Triton to load models from /models.
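Once the container is up, you can verify that the server and the model are ready before sending requests. This is a quick check with the gRPC client, assuming the default port mapping above (gRPC on 8001):

from tritonclient.grpc import InferenceServerClient

client = InferenceServerClient(url="localhost:8001")
print("Server live:", client.is_server_live())
print("Server ready:", client.is_server_ready())
print("Model ready:", client.is_model_ready("stable_diffusion"))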
Step 7: Test Inference with a Python Client
Use the tritonclient Python library to send requests. Below is a simple example:
import numpy as np
from tritonclient.grpc import InferenceServerClient, InferInput

# Connect to Triton Server
client = InferenceServerClient(url="localhost:8001")
# Prepare the input
input_text = np.array(["a serene mountain view"], dtype=object).reshape(-1, 1)
input_tensor = InferInput("INPUT_TEXT", input_text.shape, "BYTES")
input_tensor.set_data_from_numpy(input_text)
# Perform inference
response = client.infer(model_name="stable_diffusion", inputs=[input_tensor])
output_data = response.as_numpy("OUTPUT")
print("Generated image shape:", output_data.shape)