Pruna x Koyeb

Getting Started with Pruna & Koyeb

Koyeb is a serverless platform for seamlessly deploying AI applications and databases on high-performance infrastructure, including CPUs, GPUs, and Accelerators around the world.

However, unoptimized models can rack up costs, slow down inference, and waste resources. In this guide, we’ll walk you through how to use pruna to optimize your models and deploy them on Koyeb. We will show you how to supercharge your Flux model with pruna_pro, but the same workflow applies to pruna: simply adjust the installation command and the compression configuration.

Deploy to Koyeb

Requirements

  • You have created a Koyeb account and have access to the platform.

  • You have installed the Koyeb Command Line Interface (CLI) on your development machine.

  • You have Python (version 3.9 or higher) installed in your local development environment.

Step 1: Install Pruna

To use pruna with Koyeb, you’ll need Python 3.9 or higher and any NVIDIA GPU instance on Koyeb. We will use uv to install and manage our project dependencies, and FastAPI to build the application, serve the model, and handle prediction requests.

Get started by initializing a new project using uv:

uv init pruna-on-koyeb

Next, install the dependencies required by our application, including uvicorn to serve the FastAPI app and hf_transfer for faster model downloads (which the Dockerfile enables later):

uv add fastapi uvicorn diffusers torch pruna_pro hf_transfer

Then, create a new file server.py containing the following complete implementation of our application. In the next section, we will break down the different steps used to optimize the model with Pruna.

import base64
import io
import logging
import os
import time
from contextlib import asynccontextmanager
from typing import List, Literal, Optional

import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, HTTPException
from pruna_pro import SmashConfig, smash
from pydantic import BaseModel, Field, field_validator

os.environ["TOKENIZERS_PARALLELISM"] = "false"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "black-forest-labs/FLUX.1-dev"

class ModelManager:
    """Manages the loading, unloading and access to the Flux model pipeline."""

    def __init__(self):
        self.device = DEVICE

    async def load_model(self):
        logger.info(f"Loading model {MODEL_ID} on device {self.device}...")
        try:
            base_pipe = FluxPipeline.from_pretrained(
                "black-forest-labs/FLUX.1-dev",
                torch_dtype=torch.bfloat16,
            ).to(self.device)

            smash_config = SmashConfig()
            smash_config["cacher"] = "taylor_auto"
            smash_config["compiler"] = "torch_compile"
            smash_config._prepare_saving = False

            self.pipe = smash(
                model=base_pipe,
                token=os.getenv("PRUNA_API_KEY"),
                smash_config=smash_config,
            )

            logger.info("Model loaded successfully")

        except Exception as e:
            logger.error(f"Error loading model: {str(e)}")
            raise RuntimeError(f"Failed to load model: {str(e)}")

    async def unload_model(self):
        """Cleanup method to properly unload models"""
        try:
            del self.pipe

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info("Model components unloaded successfully")
        except Exception as e:
            logger.error(f"Error unloading model components: {str(e)}")

model_manager = ModelManager()

class GenerationRequest(BaseModel):
    """Request model for image generation containing all parameters."""

    prompt: str = Field(..., min_length=1, max_length=1000)
    prompt_2: Optional[str] = Field(None, min_length=1, max_length=1000)
    height: Optional[int] = Field(1024, ge=64, le=2048)
    width: Optional[int] = Field(1024, ge=64, le=2048)
    num_inference_steps: int = Field(50, ge=1, le=100)
    guidance_scale: float = Field(3.5, ge=0.0, le=10.0)
    max_sequence_length: int = Field(256, ge=1, le=256)
    num_images_per_prompt: int = Field(1, ge=1, le=5)
    seed: Optional[int] = Field(None)
    speed_mode: Literal[
        "Lightly Juiced 🍊 (more consistent)",
        "Juiced 🔥 (default)",
        "Extra Juiced 🔥 (more speed)",
    ] = Field(default="Juiced 🔥 (default)")

    @field_validator("height", "width")
    @classmethod
    def validate_dimensions(cls, v: Optional[int]) -> Optional[int]:
        if v is not None and v % 8 != 0:
            raise ValueError("Height and width must be divisible by 8")
        return v

    class Config:
        json_schema_extra = {
            "example": {
                "prompt": "A beautiful landscape",
                "height": 1024,
                "width": 1024,
                "num_inference_steps": 30,
            }
        }

class GenerationResponse(BaseModel):
    images: List[str]
    seed: int

def encode_images(images, quality: int = 85):
    """
    Encode PIL images to base64 JPEG strings.

    Args:
        images: List of PIL Image objects
        quality: JPEG quality (1-100)
    Returns:
        List of base64 encoded image strings
    """

    encoded_images = []
    for img in images:
        with io.BytesIO() as buffered:
            img = img.convert("RGB")
            img.save(buffered, format="JPEG", quality=quality, optimize=True)
            img_str = base64.b64encode(buffered.getvalue()).decode()
            encoded_images.append(f"data:image/jpeg;base64,{img_str}")
    return encoded_images

@asynccontextmanager
async def lifespan(_: FastAPI):
    await model_manager.load_model()
    yield
    await model_manager.unload_model()
    logger.info("Application shut down successfully")

app = FastAPI(
    title="flux.1-juiced API",
    description="API for generating images using FLUX.1 [dev] with Pruna AI",
    version="1.0.0",
    lifespan=lifespan,
)

@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "device": model_manager.device,
        "model_loaded": hasattr(model_manager, "pipe"),
    }

@app.post("/predict", response_model=GenerationResponse)
async def predict(request: GenerationRequest):
    try:
        logger.info(f"Starting image generation with prompt: {request.prompt[:50]}...")
        logger.debug(f"Generation parameters: {request.model_dump()}")

        start_time = time.time()

        if request.seed is not None:
            torch.manual_seed(request.seed)
        else:
            request.seed = torch.randint(0, 2**32 - 1, (1,)).item()

        pipeline = model_manager.pipe

        if hasattr(pipeline, "cache_helper"):
            # Reset the cache before configuring it for this request.
            pipeline.cache_helper.disable()
            pipeline.cache_helper.enable()

            # Map the requested speed mode to a cache speed factor
            # (lower values trade a little consistency for more speed).
            if request.speed_mode == "Lightly Juiced 🍊 (more consistent)":
                speed_factor = 0.5 if request.num_inference_steps > 20 else 0.6
            elif request.speed_mode == "Extra Juiced 🔥 (more speed)":
                speed_factor = 0.3 if request.num_inference_steps > 20 else 0.4
            else:  # "Juiced 🔥 (default)"
                speed_factor = 0.4 if request.num_inference_steps > 20 else 0.5

            logger.info(f"Setting cache speed factor: {speed_factor}")
            pipeline.cache_helper.set_params(speed_factor=speed_factor)
        else:
            logger.warning("Selected pipeline does not have cache_helper.")

        images = pipeline(
            prompt=request.prompt,
            prompt_2=request.prompt_2,
            height=request.height,
            width=request.width,
            num_inference_steps=request.num_inference_steps,
            guidance_scale=request.guidance_scale,
            max_sequence_length=request.max_sequence_length,
            num_images_per_prompt=request.num_images_per_prompt,
        ).images

        generation_time = time.time() - start_time
        logger.info(f"Image generation completed in {generation_time:.2f} seconds")
        logger.info(f"Generated {len(images)} images")

        encoded_images = encode_images(images)

        return GenerationResponse(images=encoded_images, seed=request.seed)

    except Exception as e:
        logger.error(f"Error during image generation: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

Let’s break down the different steps used in the example above:

Model loading

We’re using FluxPipeline as the baseline model before optimization.

from diffusers import FluxPipeline
import torch

# Load the base (unoptimized) pipeline
base_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to(self.device)

Configure Pruna for Optimization

Right after this, we create a SmashConfig object that specifies how the model should be optimized. pruna lets you combine optimization components such as caching (here, the taylor_auto cacher) and compilation (here, torch_compile).

# Configure Pruna Smash
smash_config = SmashConfig()
smash_config["cacher"] = "taylor_auto"
smash_config["compiler"] = "torch_compile"
smash_config._prepare_saving = False

Optimize the Model

Pass your model and configuration to the smash() function, which applies the optimizations.

# Optimize the model
self.pipe = smash(
    model=base_pipe,
    token=os.getenv("PRUNA_API_KEY"),  # your pruna_pro token, read from the PRUNA_API_KEY environment variable
    smash_config=smash_config,
)
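
As noted in the introduction, the same workflow applies to the open-source pruna package. Here is a rough sketch of the equivalent step (the exact set of available algorithms differs between pruna and pruna_pro; deepcache is used here as an example cacher from the open-source package, and no token is required):

from pruna import SmashConfig, smash

# Open-source configuration: a caching algorithm plus torch.compile
smash_config = SmashConfig()
smash_config["cacher"] = "deepcache"
smash_config["compiler"] = "torch_compile"

# No token argument is needed with the open-source package
self.pipe = smash(
    model=base_pipe,
    smash_config=smash_config,
)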

Use the Optimized Model

After optimization, the model is ready for prediction.

# Generate output
images = pipeline(
    prompt=request.prompt,
    prompt_2=request.prompt_2,
    height=request.height,
    width=request.width,
    num_inference_steps=request.num_inference_steps,
    guidance_scale=request.guidance_scale,
    max_sequence_length=request.max_sequence_length,
    num_images_per_prompt=request.num_images_per_prompt,
).images
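
One practical note: torch_compile compiles the model lazily, so the first generation after a cold start is noticeably slower than the ones that follow. If you prefer to pay that cost at startup instead of on the first user request, you could run a short warm-up generation at the end of load_model. A minimal sketch of such an addition (hypothetical, not part of the listing above):

# Hypothetical warm-up: trigger compilation once at startup.
# Using the default production resolution helps avoid an extra
# recompilation when real requests arrive.
logger.info("Warming up the compiled pipeline...")
self.pipe(prompt="warm-up", height=1024, width=1024, num_inference_steps=4)
logger.info("Warm-up complete")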

Deploying the Optimized Model on Koyeb

First, create a Dockerfile in your project's root directory with the following content:

FROM nvidia/cuda:12.8.1-cudnn-runtime-ubuntu22.04

COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/

ENV DEBIAN_FRONTEND=noninteractive \
    UV_COMPILE_BYTECODE=1 \
    UV_LINK_MODE=copy \
    HF_HOME=/workspace/model-cache \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PATH="/workspace/.venv/bin:$PATH" \
    PORT=8000

RUN apt-get update && \
    apt-get install -y build-essential && \
    apt-get clean && \
    rm -rf /var/lib/apt/lists/*

WORKDIR /workspace

COPY . ./

RUN uv python pin 3.12.0 && \
    uv sync

ENTRYPOINT uvicorn server:app --host 0.0.0.0 --port ${PORT:-8000}

You can deploy to Koyeb using the control panel or the Koyeb CLI. In this guide, we will deploy using the CLI. Since FLUX.1-dev is a gated model on the Hugging Face Hub, you will likely also need to provide a Hugging Face access token (for example via an HF_TOKEN environment variable) so the service can download the weights.

koyeb deploy . pruna-on-koyeb \
   --instance-type gpu-nvidia-l40s \
   --region na \
   --type web \
   --port 8000:http \
   --archive-builder \
   --env PRUNA_API_KEY=<your_pruna_api_key>

Once the build completes, your service will be deployed and running on Koyeb, and you will be able to make your first predictions against your Koyeb domain ending with .koyeb.app. Note that the first startup can take a few minutes while the model is downloaded and optimized. Check out the /docs endpoint for the interactive API documentation and details on how to run your first prediction.
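
To test the deployment end to end, you can call the /predict endpoint and decode the base64-encoded images it returns. A minimal client sketch (assuming the requests library is installed locally; replace the placeholder URL with your own Koyeb domain):

import base64

import requests

BASE_URL = "https://<your-app>.koyeb.app"  # replace with your Koyeb domain

response = requests.post(
    f"{BASE_URL}/predict",
    json={
        "prompt": "A beautiful landscape",
        "height": 1024,
        "width": 1024,
        "num_inference_steps": 30,
    },
    timeout=600,
)
response.raise_for_status()
payload = response.json()

# Each entry is a data URL: strip the "data:image/jpeg;base64," prefix and decode the rest.
for i, data_url in enumerate(payload["images"]):
    image_bytes = base64.b64decode(data_url.split(",", 1)[1])
    with open(f"output_{i}.jpg", "wb") as f:
        f.write(image_bytes)

print(f"Seed used: {payload['seed']}")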

Ending notes

Congrats! You have deployed your optimized model to Koyeb! Combining Pruna’s optimizations with Koyeb’s serverless deployments gives you the best of both worlds: high performance, cost efficiency, and scalability, all in one.