Using Pruna with Replicate
What is Replicate?
Deep learning models are getting bigger and more powerful, but making them run efficiently can be tricky. That’s where Replicate comes in—it’s a great platform for running machine learning models in production. But even with its awesome capabilities, unoptimized models can rack up costs, slow down inference, and waste resources.
That’s why we’re excited to introduce you to Pruna—a smart, easy-to-use package designed to make your models faster, leaner, and more efficient.
In this guide, we’ll walk you through how to use Pruna to optimize your models and deploy them on Replicate. Whether you’re a seasoned ML engineer or just starting out, this guide will help you boost performance and get the most out of your models.
Why Optimize Your Models?
When it comes to deploying machine learning models, optimization is key to ensuring scalability and cost-effectiveness. Here’s why you should care:
Faster Inference Times: Optimized models run quicker, leading to better user experiences and, as a result, higher user retention.
Smaller Models: Optimized models are smaller, so you can either run the same model on a smaller GPU or fit multiple models on one GPU.
Lower Computational Costs: Faster inference means you pay less for the same number of requests handled, and smaller models let you move to cheaper GPUs.
Environmental Impact: Smaller and faster models consume less energy, making AI more sustainable.
Traditionally, model optimization required deep technical knowledge and significant effort. Pruna simplifies this by automating advanced techniques like quantization, pruning, compilation, and caching, tailored to your specific use case.
Getting Started with Pruna & Replicate
Step 1: Install Pruna
To use Pruna with Replicate, you’ll need Python ≥3.9 and an NVIDIA GPU (any GPU type offered by Replicate will do). Replicate uses the Cog framework for containerizing and deploying models. To integrate Pruna, update your cog.yaml file. Here’s an example configuration:
build:
  gpu: true
  cuda: "12.1"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    - "git"
    - "build-essential"
  python_version: "3.11"
  run:
    - command: pip install pruna[gpu]==0.1.3 --extra-index-url https://prunaai.pythonanywhere.com/
    - command: pip install colorama
    - command: export CC=/usr/bin/gcc
predict: "predict.py:Predictor"
This setup ensures that Pruna is available during the build process and integrates seamlessly with your model code.
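Before adding any model code, it can be useful to confirm that the container builds and that Pruna imports with GPU support. Here is a quick sanity check you could run inside the built image, for example with cog run python check_install.py (the file name is just an example, not something Cog requires):

# check_install.py - confirm Pruna is installed and the GPU is visible inside the container
from importlib.metadata import version

import torch

print("pruna:", version("pruna"))
print("CUDA available:", torch.cuda.is_available())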
Step 2: Optimize Your Model
Pruna packs powerful optimization techniques into a simple, easy-to-use interface. You can check out the rest of the documentation for a developer-friendly guide on how to use Pruna. In this guide, we will walk through an example using the Flux Schnell model, based largely on this tutorial.
Load Your Model
Start by loading your model using FluxPipeline. This will serve as the baseline model before optimization.
from diffusers import FluxPipeline
import torch

# Load the model
self.pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    token="your_hugging_face_token"
).to("cuda")
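Hardcoding the Hugging Face token is fine for a quick test, but for anything you plan to push it is safer to read it from the environment. A minimal sketch, assuming you expose the token under an HF_TOKEN environment variable (the variable name is just a convention, not something Pruna or Replicate requires):

import os

import torch
from diffusers import FluxPipeline

# Read the Hugging Face token from the environment instead of hardcoding it
hf_token = os.environ.get("HF_TOKEN")

self.pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    token=hf_token,
).to("cuda")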
Configure Pruna for Optimization
Define a SmashConfig object that specifies how the model should be optimized. Pruna allows you to customize parameters like caching and compilation.
from pruna import SmashConfig, smash

# Configure Pruna Smash
smash_config = SmashConfig()
smash_config['compilers'] = ['flux_caching']           # enable the Flux caching compiler
smash_config['comp_flux_caching_cache_interval'] = 2   # caching interval, in inference steps
smash_config['comp_flux_caching_start_step'] = 0       # inference step at which caching starts
Optimize the Model
Pass your model and configuration to Pruna’s smash() function, which applies the optimizations.
# Optimize the model
self.pipe = smash(
    model=self.pipe,
    token='<your_token>',  # replace <your_token> with your actual token, or set to None if you do not have one yet
    smash_config=smash_config,
)
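If you want to quantify the speed-up before deploying, you can time the pipeline before and after smash(). Below is a minimal sketch of that flow, reusing the pipeline and smash_config from the previous steps; the prompt, step count, run count, and warm-up pass are illustrative choices, not part of Pruna’s API:

import time

import torch
from pruna import smash

def average_latency(pipe, prompt="a photo of an astronaut", steps=4, runs=3):
    """Average end-to-end latency in seconds over a few runs, after one warm-up pass."""
    pipe(prompt, num_inference_steps=steps)  # warm-up (triggers compilation, fills caches)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        pipe(prompt, num_inference_steps=steps)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / runs

baseline_s = average_latency(self.pipe)  # before optimization

self.pipe = smash(model=self.pipe, token=None, smash_config=smash_config)

optimized_s = average_latency(self.pipe)  # after optimization
print(f"baseline: {baseline_s:.2f}s, optimized: {optimized_s:.2f}s")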
Use the Optimized Model
After optimization, the model is ready for prediction. For example, you can adjust caching parameters dynamically before generating outputs.
# Optionally adjust caching parameters before generating
self.pipe.flux_cache_helper.set_params(cache_interval=3, start_step=1)

# Generate output
image = self.pipe(
    prompt="Your prompt here",
    num_inference_steps=4,
).images[0]
Full Code Example
Below is the complete implementation combining all the steps above into a single Predictor class:
import tempfile

import torch
from cog import BasePredictor, Input, Path
from diffusers import FluxPipeline
from pruna import SmashConfig, smash


class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load and optimize the model"""
        # Load the model
        self.pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=torch.bfloat16,
            token="your_hugging_face_token"
        ).to("cuda")

        # Configure Pruna
        smash_config = SmashConfig()
        smash_config['compilers'] = ['flux_caching']
        smash_config['comp_flux_caching_cache_interval'] = 2
        smash_config['comp_flux_caching_start_step'] = 0

        # Optimize the model
        self.pipe = smash(
            model=self.pipe,
            token="your_pruna_token",
            smash_config=smash_config,
        )

    def predict(
        self,
        prompt: str = Input(description="Prompt"),
        num_inference_steps: int = Input(
            description="Number of inference steps", default=4
        ),
        guidance_scale: float = Input(description="Guidance scale", default=7.5),
        seed: int = Input(description="Seed", default=42),
        image_height: int = Input(description="Image height", default=1024),
        image_width: int = Input(description="Image width", default=1024),
        cache_interval: int = Input(description="Cache interval", default=3),
        start_step: int = Input(description="Start step", default=1),
    ) -> Path:
        """Run a prediction"""
        self.pipe.flux_cache_helper.set_params(
            cache_interval=cache_interval, start_step=start_step
        )

        image = self.pipe(
            prompt,
            height=image_height,
            width=image_width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator("cpu").manual_seed(seed),
        ).images[0]

        output_dir = Path(tempfile.mkdtemp())
        image_path = output_dir / "output.png"
        image.save(image_path)
        return image_path
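Before pushing anything, you can exercise this predictor locally with Cog, which builds the image from cog.yaml and calls Predictor.predict with the inputs you pass on the command line. The prompt below is just an example:

cog predict -i prompt="a tiny astronaut hatching from an egg on the moon" -i num_inference_steps=4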
Deploying the Optimized Model on Replicate
Step 1: Setup GitHub Workflow
To streamline the deployment, you can use GitHub Actions to automate pushing models to Replicate. Here’s an example workflow file (push_flux_schnell.yaml):
name: Push Flux Schnell to Replicate

on:
  workflow_dispatch:
    inputs:
      model_name:
        default: "prunaai/flux-schnell"

jobs:
  push_to_replicate:
    name: Push to Replicate
    runs-on: ubuntu-latest
    steps:
      - name: Free disk space
        uses: jlumbroso/free-disk-space@main  # pin to a specific release in practice
      - name: Checkout
        uses: actions/checkout@v4
      - name: Setup Cog
        uses: replicate/setup-cog@v2
        with:
          token: ${{ secrets.REPLICATE_API_TOKEN }}
      - name: Push to Replicate
        run: |
          cog push
This workflow automates the deployment of your optimized model, saving time and effort.
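One detail worth noting: the workflow defines a model_name input, but cog push without arguments pushes to whatever image your cog.yaml specifies. If you would rather pass the destination explicitly, a possible variant of the last step (assuming your Replicate model lives under r8.im/<model_name>) is:

      - name: Push to Replicate
        run: |
          cog push r8.im/${{ inputs.model_name }}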
Step 2: Push the Optimized Model
Once your workflow is set up, you can simply run the GitHub Action to deploy your model to Replicate.
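You can trigger the workflow from the Actions tab in GitHub, or from your terminal with the GitHub CLI; a sketch, assuming gh is installed and authenticated for your repository:

gh workflow run push_flux_schnell.yaml -f model_name="prunaai/flux-schnell"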
Ending Notes
Congratulations, you deployed your optimized model to Replicate!
When you combine Pruna’s optimization superpowers with Replicate’s seamless deployment platform, you get the best of both worlds—performance, cost savings, and scalability all in one!
In just a few steps, you can make your models run faster, use fewer resources, and deliver even better results. Curious to see it in action? Check out the models we’ve deployed using Pruna and Replicate on our Pruna Replicate repository.
Excited to give it a shot? Download Pruna, integrate it into your workflow, and deploy your optimized models on Replicate. If you have any questions, our team is here to support you! Join the conversation and ask us anything on Discord.