Optimize Flux2 Klein (4B) Image Generation
| Component | Details |
|---|---|
| Goal | Demonstrate optimizing and evaluating Flux2 Klein 4B with FORA, quantization, and torch compile |
| Model | black-forest-labs/FLUX.2-klein-base-4B |
| Optimization Algorithms | cacher(fora), quantizer(torchao fp8), compiler(torch_compile) |
| Evaluation | Baseline vs optimized latency comparison |
This tutorial demonstrates how to use Pruna to speed up image generation with Flux2 using a combination of three optimization techniques:
FORA (Fast-Forward Caching) - Caches transformer block outputs and reuses them for subsequent diffusion steps
TorchAO Quantization (FP8) - Reduces memory bandwidth by using 8-bit floating point weights
Torch Compile - JIT compiles the model for optimized GPU execution
Together, these optimizations can achieve a 2-3x speedup while maintaining image quality.
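To see why 8-bit weights reduce memory traffic, a rough back-of-the-envelope estimate helps. The sketch below assumes a hypothetical set of 4 billion weights (a round figure taken from the "4B" in the model name) and compares bfloat16 storage (2 bytes per weight) with FP8 storage (1 byte per weight). `weight_gib` is an illustrative helper, not part of Pruna.

```python
def weight_gib(num_params: int, bytes_per_param: int) -> float:
    """Size of num_params weights in GiB at the given storage width."""
    return num_params * bytes_per_param / 1024**3


# Assumption: a round figure based on the "4B" in the model name.
NUM_PARAMS = 4_000_000_000

bf16_gib = weight_gib(NUM_PARAMS, 2)  # bfloat16 stores 2 bytes per weight
fp8_gib = weight_gib(NUM_PARAMS, 1)   # FP8 stores 1 byte per weight

print(f"bfloat16: {bf16_gib:.2f} GiB, FP8: {fp8_gib:.2f} GiB")
```

Because this tutorial quantizes only a subset of the layers, the saving in practice will be smaller than this upper bound.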
Prerequisites
NVIDIA GPU with CUDA support (compute capability ≥ 8.9 for FP8). FP8 quantization is hardware-specific and requires a recent GPU such as the H100.
pruna library installed (pip install pruna)
diffusers with Flux2 support
Hugging Face account with access to Flux2 models
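Before going further, it can be worth checking the compute capability requirement programmatically. The following is a minimal sketch: `supports_fp8` and `FP8_MIN_CAPABILITY` are hypothetical names introduced here, the threshold is the one stated above, and the capability tuple is read with `torch.cuda.get_device_capability()` when PyTorch and a GPU are present.

```python
FP8_MIN_CAPABILITY = (8, 9)  # threshold stated in the prerequisites


def supports_fp8(capability: tuple[int, int]) -> bool:
    """Return True if a (major, minor) compute capability meets the FP8 threshold."""
    return tuple(capability) >= FP8_MIN_CAPABILITY


if __name__ == "__main__":
    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None and torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()
        print(f"Compute capability {cap}: FP8 supported = {supports_fp8(cap)}")
    else:
        print("No CUDA device found; FP8 quantization is unavailable.")
```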
Getting Started
To install the dependencies, run the following command:
[ ]:
%pip install pruna
[ ]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
The device is set to CUDA when a GPU is available. The CPU fallback only keeps this cell from failing: the FP8 quantization and CUDA-event timing used later in this tutorial require an NVIDIA GPU.
1. Load the Model
Before optimizing the model, we load the Flux2-Klein-Base-4B pipeline. You need a HuggingFace account with access to the model; run the login cell below with your token. Do not use enable_model_cpu_offload() as it interferes with FORA’s caching mechanism.
[ ]:
# Login to HuggingFace (required for Flux2 model access)
from huggingface_hub import login
# Replace with your HuggingFace token
HF_TOKEN = "your_hf_token_here"
login(token=HF_TOKEN)
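Hardcoding a token in a notebook makes it easy to leak by accident. As an alternative, the sketch below reads the token from an environment variable; the variable name `HF_TOKEN` and the helper `resolve_hf_token` are choices made for this example, and the returned value would then be passed to `login(token=...)`.

```python
import os


def resolve_hf_token(env=os.environ):
    """Return the token stored in HF_TOKEN, or None if it is unset or empty."""
    token = env.get("HF_TOKEN", "").strip()
    return token or None


token = resolve_hf_token()
if token is None:
    print("HF_TOKEN is not set; export it before running the login cell.")
else:
    print("Found a token in HF_TOKEN; pass it to login(token=...).")
```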
[ ]:
import torch
import copy
import gc
import matplotlib.pyplot as plt
from diffusers import Flux2KleinPipeline
from pruna import SmashConfig, smash
# Check CUDA availability
device = "cuda" if torch.cuda.is_available() else "cpu"
We load the Flux2-Klein-Base-4B pipeline, a 4-billion-parameter model optimized for high-quality image generation.
[ ]:
print("Loading Flux2 pipeline...")
pipe = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-base-4B",
    torch_dtype=torch.bfloat16,
)
pipe = pipe.to(device)
print("Pipeline loaded!")
# NOTE: Do NOT use pipe.enable_model_cpu_offload()
# It breaks FORA's caching mechanism by re-initializing model hooks
2. Define Generation Parameters
Set up the common parameters for image generation. We use the same parameters for both baseline and optimized models to ensure a fair comparison.
[ ]:
# Generation parameters
prompt = "A humanoid robot or cyborg set against a bold red background"
gen_kwargs = dict(
    image=None,
    prompt=prompt,
    height=1024,
    width=1024,
    guidance_scale=4.0,
    num_inference_steps=50,
)
# Use same seed for reproducible results
def get_generator():
    return torch.Generator(device=device).manual_seed(42)
# FORA parameters
FORA_START_STEP = 4 # Start caching after this many steps
FORA_INTERVAL = 3 # Recompute every N steps (higher = faster but lower quality)
# Benchmarking parameters
NUM_WARMUP = 2 # Warmup runs (not timed)
NUM_TIMED = 3 # Timed runs for averaging
print(f"Prompt: {prompt}")
print(f"Image size: {gen_kwargs['height']}x{gen_kwargs['width']}")
print(f"Inference steps: {gen_kwargs['num_inference_steps']}")
print(f"FORA start_step: {FORA_START_STEP}, interval: {FORA_INTERVAL}")
3. Benchmark Baseline (No Optimization)
First, we measure the baseline performance without any optimization.
[ ]:
# Move pipeline to GPU
pipe.to("cuda")
# ============================================================
# BASELINE: Warmup then Measure
# ============================================================
print(f"--- Warming up BASELINE ({NUM_WARMUP} runs) ---")
for i in range(NUM_WARMUP):
    print(f" Warmup {i+1}/{NUM_WARMUP}...")
    _ = pipe(**gen_kwargs, generator=get_generator())
torch.cuda.synchronize()
print("Baseline warmup complete.")
print(f"\n--- Timing BASELINE ({NUM_TIMED} runs) ---")
baseline_times = []
for i in range(NUM_TIMED):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    image_baseline = pipe(**gen_kwargs, generator=get_generator()).images[0]
    end_event.record()
    torch.cuda.synchronize()
    elapsed = start_event.elapsed_time(end_event) / 1000  # ms to seconds
    baseline_times.append(elapsed)
    print(f" Run {i+1}: {elapsed:.2f}s")
avg_baseline = sum(baseline_times) / len(baseline_times)
print(f"\nBaseline average: {avg_baseline:.2f}s")
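The warm-up-then-measure pattern used above is repeated later for the optimized model, so it can be factored into a helper. The sketch below is one way to do that; `benchmark` and its `sync` parameter are hypothetical names, and it uses host-side `time.perf_counter` rather than the CUDA events used in this tutorial, with an optional synchronization callback so that queued GPU work is included in the measurement.

```python
import time


def benchmark(fn, num_warmup=2, num_timed=3, sync=None):
    """Call fn() num_warmup times untimed, then return num_timed durations in seconds."""

    def _sync():
        if sync is not None:
            sync()  # e.g. torch.cuda.synchronize, so queued GPU work has finished

    for _ in range(num_warmup):
        fn()
    _sync()

    times = []
    for _ in range(num_timed):
        _sync()
        start = time.perf_counter()
        fn()
        _sync()
        times.append(time.perf_counter() - start)
    return times


# Stand-in workload so the sketch runs without a GPU
times = benchmark(lambda: sum(range(100_000)))
print(f"average: {sum(times) / len(times):.4f}s over {len(times)} runs")
```

In this notebook, the call could look like `benchmark(lambda: pipe(**gen_kwargs, generator=get_generator()), NUM_WARMUP, NUM_TIMED, sync=torch.cuda.synchronize)`.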
4. Define SmashConfig and Smash the Model
Now we create an optimized model using Pruna’s smash function with three techniques:
FORA Parameters
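``fora_interval``: How often transformer outputs are recomputed once caching is active (higher = faster, but lower quality)
``fora_start_step``: Number of initial steps to run without caching (higher = better quality, because more of the early steps are computed in full, but a smaller speedup)
``fora_backbone_calls_per_step``: Set to 2 for Classifier-Free Guidance (CFG), which runs a conditional and an unconditional pass per step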
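To build intuition for how the interval and start step trade speed for quality, the sketch below models the caching schedule in plain Python. It is a simplified model, not Pruna's implementation: it assumes 0-indexed steps, that every step before the start step is computed in full, and that from the start step onward the outputs are refreshed on every `interval`-th step and reused otherwise. `fora_schedule` is a hypothetical helper.

```python
def fora_schedule(num_steps: int, start_step: int, interval: int) -> list[bool]:
    """Return one flag per step: True = compute the blocks, False = reuse the cache."""
    schedule = []
    for step in range(num_steps):
        if step < start_step:
            schedule.append(True)  # early steps always run in full
        else:
            # after start_step, refresh the cache once every `interval` steps
            schedule.append((step - start_step) % interval == 0)
    return schedule


schedule = fora_schedule(num_steps=50, start_step=4, interval=3)
computed = sum(schedule)
print(f"{computed} of {len(schedule)} steps computed in full")
```

Under these assumptions, raising the interval lowers the number of fully computed steps, while raising the start step increases it.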
TorchAO Quantization
``quant_type: fp8wo``: FP8 weight-only quantization, which reduces memory bandwidth while keeping activations in their original precision (bfloat16 here)
``target_modules``: Specifies which layers to quantize (we target ``single_transformer_blocks`` and exclude norms/embeddings)
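The include/exclude patterns can be reasoned about with a small stand-alone sketch. It assumes the patterns behave like shell-style globs as implemented by Python's `fnmatch`, which may differ from Pruna's actual matching rules; the module names are hypothetical and only illustrate the selection logic.

```python
from fnmatch import fnmatch

INCLUDE = ["*single_transformer_blocks.*"]
EXCLUDE = ["pe_embedder", "*norm*", "*embed*"]


def is_targeted(name: str) -> bool:
    """True if name matches an include pattern and no exclude pattern."""
    included = any(fnmatch(name, pattern) for pattern in INCLUDE)
    excluded = any(fnmatch(name, pattern) for pattern in EXCLUDE)
    return included and not excluded


# Hypothetical module names, for illustration only
for name in [
    "single_transformer_blocks.0.attn.to_out",
    "single_transformer_blocks.0.norm.linear",
    "transformer_blocks.0.attn.to_q",
    "pe_embedder",
]:
    print(f"{name}: {'quantize' if is_targeted(name) else 'skip'}")
```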
Torch Compile
``torch_compile: True``: Enables JIT compilation for optimized GPU kernels
[ ]:
# ============================================================
# Create optimized model with FORA + Quantization + Compile
# ============================================================
print("--- Creating optimized model ---")
smash_config = SmashConfig({
    # FORA: Cache and reuse transformer outputs
    "fora": {
        "fora_interval": FORA_INTERVAL,
        "fora_start_step": FORA_START_STEP,
        "fora_backbone_calls_per_step": 2,  # 2 for CFG (conditional + unconditional)
    },
    # TorchAO: FP8 weight-only quantization for reduced memory bandwidth
    "torchao": {
        "quant_type": "fp8wo",
        "target_modules": {
            "include": ["*single_transformer_blocks.*"],  # Quantize single-stream blocks
            "exclude": ["pe_embedder", "*norm*", "*embed*"],  # Skip norms and embeddings
        },
    },
    # Torch Compile: JIT compilation for optimized kernels
    "torch_compile": True,
})
# Create a deep copy of the pipeline and apply optimizations
smashed_model = smash(model=copy.deepcopy(pipe), smash_config=smash_config)
print("Optimized model created!")
5. Benchmark Optimized Model
The first warmup run may be slower due to torch.compile JIT compilation. Subsequent runs show the optimized performance.
[ ]:
# ============================================================
# Optimized Model: Warmup then Measure
# ============================================================
print(f"--- Warming up optimized model ({NUM_WARMUP} runs) ---")
print("(First run includes torch.compile JIT compilation overhead)")
for i in range(NUM_WARMUP):
    print(f" Warmup {i+1}/{NUM_WARMUP}...")
    _ = smashed_model(**gen_kwargs, generator=get_generator())
torch.cuda.synchronize()
print("Warmup complete.")
print(f"\n--- Timing optimized model ({NUM_TIMED} runs) ---")
fora_times = []
for i in range(NUM_TIMED):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start_event.record()
    image_fora = smashed_model(**gen_kwargs, generator=get_generator()).images[0]
    end_event.record()
    torch.cuda.synchronize()
    elapsed = start_event.elapsed_time(end_event) / 1000  # ms to seconds
    fora_times.append(elapsed)
    print(f" Run {i+1}: {elapsed:.2f}s")
avg_fora = sum(fora_times) / len(fora_times)
print(f"\nOptimized average: {avg_fora:.2f}s")
6. Results Comparison
[ ]:
# ============================================================
# Results Summary
# ============================================================
speedup = avg_baseline / avg_fora
print(f"\n{'='*60}")
print(f"RESULTS: FORA + FP8 Quantization + Torch Compile")
print(f"{'='*60}")
print(f"FORA: start_step={FORA_START_STEP}, interval={FORA_INTERVAL}")
print(f"Quantization: FP8 weight-only on single_transformer_blocks")
print(f"{'='*60}")
print(f"Baseline: {avg_baseline:.2f}s")
print(f"Optimized: {avg_fora:.2f}s")
print(f"Speedup: {speedup:.2f}x")
print(f"Time saved: {avg_baseline - avg_fora:.2f}s per image")
print(f"{'='*60}")
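With only a few timed runs, it helps to report the spread alongside the mean. The sketch below computes the same speedup as the summary above, plus a standard deviation and the percentage of time saved; `compare` is a hypothetical helper, and the timings passed to it are made-up placeholders, not measured results.

```python
from statistics import mean, stdev


def compare(baseline_times, optimized_times):
    """Summarize two lists of per-run latencies (seconds)."""
    base, opt = mean(baseline_times), mean(optimized_times)
    return {
        "baseline_mean": base,
        "optimized_mean": opt,
        # stdev needs at least two samples
        "baseline_std": stdev(baseline_times) if len(baseline_times) > 1 else 0.0,
        "optimized_std": stdev(optimized_times) if len(optimized_times) > 1 else 0.0,
        "speedup": base / opt,
        "time_saved_pct": 100 * (base - opt) / base,
    }


# Made-up timings, for illustration only (not measured results)
summary = compare([10.0, 10.2, 9.8], [4.9, 5.0, 5.1])
for key, value in summary.items():
    print(f"{key}: {value:.2f}")
```

In this notebook, the call would be `compare(baseline_times, fora_times)`.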
[ ]:
# ============================================================
# Visual Comparison
# ============================================================
fig, axes = plt.subplots(1, 2, figsize=(14, 7))
axes[0].imshow(image_baseline)
axes[0].set_title(f"Baseline\n{avg_baseline:.2f}s", fontsize=14)
axes[0].axis("off")
axes[1].imshow(image_fora)
axes[1].set_title(f"Optimized (FORA + FP8 + Compile)\n{avg_fora:.2f}s ({speedup:.2f}x faster)", fontsize=14)
axes[1].axis("off")
plt.suptitle(f'Prompt: "{prompt}"', fontsize=12, y=1.02)
plt.tight_layout()
plt.savefig("baseline_vs_fora.png", dpi=150, bbox_inches="tight")
plt.show()
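A side-by-side plot gives a visual impression only. As one possible numeric complement, the sketch below computes the peak signal-to-noise ratio (PSNR) between two sequences of 8-bit pixel values in pure Python. `psnr` is a hypothetical helper, and PSNR is a coarse proxy: because caching changes the denoising trajectory, a lower score does not necessarily mean a worse-looking image.

```python
import math


def psnr(pixels_a, pixels_b, max_value=255):
    """Peak signal-to-noise ratio (dB) between two equal-length sequences of 8-bit values."""
    if len(pixels_a) != len(pixels_b):
        raise ValueError("inputs must have the same length")
    if len(pixels_a) == 0:
        raise ValueError("inputs must not be empty")
    mse = sum((a - b) ** 2 for a, b in zip(pixels_a, pixels_b)) / len(pixels_a)
    if mse == 0:
        return math.inf  # identical inputs
    return 10 * math.log10(max_value**2 / mse)


# Tiny synthetic example; in the notebook, image_baseline.tobytes() and
# image_fora.tobytes() would supply the raw 8-bit pixel values instead.
a = bytes([10, 20, 30, 40])
b = bytes([12, 18, 30, 41])
print(f"PSNR: {psnr(a, b):.2f} dB")
```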
7. Cleanup
[ ]:
# Free GPU memory: drop the reference, collect garbage, then release cached blocks
del smashed_model
gc.collect()
torch.cuda.empty_cache()
Conclusion
In this tutorial, we demonstrated a workflow for optimizing and evaluating the Flux2 Klein 4B image generation model using Pruna. We defined a SmashConfig combining FORA (cacher), TorchAO FP8 quantization, and torch compile, applied it with smash, and compared baseline and optimized latency. Together, these optimizations can deliver a 2-3x speedup while maintaining image quality; the exact figure depends on your hardware and on the FORA settings. You can adapt the configuration to your use case or reach out on Discord with questions.