Pruna x Replicate

Getting Started with Pruna & Replicate

Replicate is a great platform for running machine learning models in production. However, unoptimized models can rack up costs, slow down inference, and waste resources. In this guide, we’ll walk you through how to use pruna to optimize your models and deploy them on Replicate. We’ll show how to supercharge a Flux model with pruna_pro, but the same workflow applies to the open-source pruna package: simply adjust the installation command and the compression configuration.

Step 1: Install Pruna

To use pruna with Replicate, you’ll need Python ≥3.9 and an NVIDIA GPU, which Replicate provides. Replicate uses the cog framework for containerizing and deploying models, so to integrate pruna you’ll need to update your cog.yaml file. Here’s an example configuration:

build:
  gpu: true
  cuda: "12.4"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
    - "git"
    - "build-essential"
  python_version: "3.11"
  run:
    - command: pip install pruna_pro==0.2.0  # or pip install pruna==0.2.0
    - command: pip install colorama
    - command: export CC=/usr/bin/gcc
predict: "predict.py:Predictor"

This setup ensures that pruna is available during the build process and integrates seamlessly with your model code.

Step 2: Optimize Your Model

In this guide, we will walk through an example using the Flux Schnell model, largely based on this tutorial.

  1. Load Your Model

Start by loading your model using FluxPipeline. This will serve as the baseline model before optimization.

from diffusers import FluxPipeline
import torch

# Load the model (this code runs inside the Predictor's setup() method, hence self.pipe)
self.pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    token="your_hugging_face_token"  # your Hugging Face access token
).to("cuda")

  2. Configure Pruna for Optimization

Define a SmashConfig object that specifies how the model should be optimized. pruna allows you to customize parameters like caching and compilation.

from pruna_pro import SmashConfig, smash

# Configure Pruna Smash
smash_config = SmashConfig()
smash_config['cacher'] = 'periodic'
smash_config['periodic_cache_interval'] = 2
smash_config['periodic_start_step'] = 2
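
The config above only enables caching. Since compilation was mentioned as well, the line below sketches how a compiler could additionally be set; the 'torch_compile' name is an assumption here and supported compiler names may differ between pruna versions, so check the pruna documentation.

# Optionally add a compiler on top of caching
# ('torch_compile' is an assumed algorithm name; see the pruna docs for supported compilers)
smash_config['compiler'] = 'torch_compile'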

  3. Optimize the Model

Pass your model and configuration to Pruna’s smash() function, which applies the optimizations.

# Optimize the model
self.pipe = smash(
    model=self.pipe,
    token='<your_pruna_token>',  # replace with your actual pruna token, or set to None if you do not have one yet
    smash_config=smash_config,
)

  4. Use the Optimized Model

After optimization, the model is ready for prediction and is called just like the original pipeline; the caching configured above is applied automatically during generation. Caching parameters can also be adjusted dynamically before generating outputs, which is why they are exposed as inputs in the full example below.

# Generate output
image = self.pipe(
    prompt="Your prompt here",
    num_inference_steps=4,
).images[0]
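
As mentioned in the introduction, the same workflow applies to the open-source pruna package. The sketch below shows what the configuration and smash() call could look like with open-source algorithms; the 'deepcache' cacher and its 'deepcache_interval' hyperparameter are assumptions here and may differ between pruna versions, so check the pruna documentation for the exact names.

from pruna import SmashConfig, smash  # open-source package instead of pruna_pro

# Configure a cacher available in open-source pruna
# (assumed names; the exact algorithm and hyperparameter names may vary by version)
smash_config = SmashConfig()
smash_config['cacher'] = 'deepcache'
smash_config['deepcache_interval'] = 2

# Apply the configured optimizations to the pipeline
self.pipe = smash(
    model=self.pipe,
    smash_config=smash_config,
)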

Full Code Example

Below is the complete implementation combining all the steps above into a single Predictor class:

import tempfile
import torch
from cog import BasePredictor, Input, Path
from diffusers import FluxPipeline
from pruna_pro import SmashConfig, smash

class Predictor(BasePredictor):
    def setup(self) -> None:
        """Load and optimize the model"""
        # Load the model
        self.pipe = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-schnell",
            torch_dtype=torch.bfloat16,
            token="your_hugging_face_token"
        ).to("cuda")

        # Configure Pruna
        smash_config = SmashConfig()
        smash_config['cacher'] = 'periodic'
        smash_config['periodic_cache_interval'] = 2
        smash_config['periodic_start_step'] = 2

        # Optimize the model
        self.pipe = smash(
            model=self.pipe,
            token="your_pruna_token",
            smash_config=smash_config,
        )

    def predict(
        self,
        prompt: str = Input(description="Prompt"),
        num_inference_steps: int = Input(
            description="Number of inference steps", default=4
        ),
        guidance_scale: float = Input(
            description="Guidance scale", default=7.5
        ),
        seed: int = Input(description="Seed", default=42),
        image_height: int = Input(description="Image height", default=1024),
        image_width: int = Input(description="Image width", default=1024),
        # note: cache_interval and start_step are exposed as inputs but are not
        # wired up to the cacher in this minimal example
        cache_interval: int = Input(description="Cache interval", default=3),
        start_step: int = Input(description="Start step", default=1),
    ) -> Path:
        """Run a prediction"""
        image = self.pipe(
            prompt,
            height=image_height,
            width=image_width,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator("cpu").manual_seed(seed)
        ).images[0]

        output_dir = Path(tempfile.mkdtemp())
        image_path = output_dir / "output.png"
        image.save(image_path)
        return image_path

Deploying the Optimized Model on Replicate

Step 1: Setup GitHub Workflow

To streamline the deployment, you can use GitHub Actions to automate pushing models to Replicate. Here’s an example workflow file (push_flux_schnell.yaml):

name: Push Flux Schnell to Replicate
on:
  workflow_dispatch:
    inputs:
      model_name:
        default: "prunaai/flux-schnell"
jobs:
  push_to_replicate:
    name: Push to Replicate
    runs-on: ubuntu-latest
    steps:
      - name: Free disk space
        uses: jlumbroso/free-disk-space@main
      - name: Checkout
        uses: actions/checkout@v4
      - name: Setup Cog
        uses: replicate/setup-cog@v2
        with:
          token: ${{ secrets.REPLICATE_API_TOKEN }}
      - name: Push to Replicate
        run: |
          cog push r8.im/${{ inputs.model_name }}

This workflow automates the deployment of your optimized model, saving time and effort.

Step 2: Push the Optimized Model

Once your workflow is set up, you can simply run the GitHub Action to deploy your model to Replicate.
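
Once the push completes, the model is live on Replicate and can be called from any application. As a quick sanity check, here is a minimal sketch using the Replicate Python client; it assumes the model was pushed under the default name prunaai/flux-schnell from the workflow above and that REPLICATE_API_TOKEN is set in your environment.

import replicate

# Call the deployed model; the output format depends on how your predictor
# returns the image (a single file path in the Predictor above)
output = replicate.run(
    "prunaai/flux-schnell",
    input={
        "prompt": "A photo of an astronaut riding a horse",
        "num_inference_steps": 4,
    },
)
print(output)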

Ending notes

Congratulations, you deployed your optimized model to Replicate!

When you combine Pruna’s optimization superpowers with Replicate’s seamless deployment platform, you get the best of both worlds: faster inference, lower costs, and easy scalability, all in one.