P-Image-Edit LoRA: Training and Inference
This notebook is the full walkthrough for P-Image-Edit LoRA Training and Inference: we train a LoRA adapter for image editing with p-image-edit-trainer and run inference with p-image-edit-lora.
Two approaches — don’t mix them:
| Approach | Use case | Training | Inference |
|---|---|---|---|
| P-Image-Edit LoRA | Image editing (input image → edited image) | p-image-edit-trainer | p-image-edit-lora |
| P-Image LoRA | Text-to-image (prompt → image) | p-image-trainer | p-image-lora |
This notebook covers P-Image-Edit LoRA only. For P-Image LoRA (text-to-image), see the LoRA documentation for dataset format and endpoints.
In this notebook we will:
Generate 200 training image pairs using Flux-2-Klein models
Prepare the dataset in the correct format for `p-image-edit-trainer <https://replicate.com/prunaai/p-image-edit-trainer>`__
Upload the dataset to HuggingFace
Train a LoRA adapter (P-Image-Edit LoRA training) with `p-image-edit-trainer <https://replicate.com/prunaai/p-image-edit-trainer>`__
Download and extract the trained LoRA weights
Upload the LoRA weights to HuggingFace
Test inference (P-Image-Edit LoRA inference) with p-image-edit-lora
Setup
Install required packages and set up authentication.
[ ]:
%pip install replicate huggingface-hub pillow tqdm datasets
[ ]:
import os
import io
import zipfile
import time
import requests
import random
from pathlib import Path
from typing import Iterator
from IPython.display import Image, display
from tqdm import tqdm
from datasets import load_dataset
from replicate.client import Client
from huggingface_hub import HfApi, upload_file
from PIL import Image as PILImage
replicate_token = os.environ.get("REPLICATE_API_TOKEN")
if not replicate_token:
replicate_token = input("Replicate API token (r8_...): ").strip()
hf_token = os.environ.get("HF_TOKEN")
if not hf_token:
hf_token = input("HuggingFace API token (hf_...): ").strip()
replicate = Client(api_token=replicate_token)
hf_api = HfApi(token=hf_token)
Step 1: Generate Training Pairs
We’ll generate 200 image pairs using black-forest-labs/flux-2-klein-4b and black-forest-labs/flux-2-klein-9b. Both models will use the same seed to generate similar images, with the 4B model output as input (_start) and the 9B model output as target (_end).
Note: We assume that images generated by the 9B model will generally be more appealing and higher quality than the 4B model, making them suitable as training targets for enhancement.
Important: This approach is not foolproof. The quality difference between models may vary, and the generated pairs may not always represent ideal enhancement targets. For production use, consider using curated, high-quality training pairs.
[ ]:
n_samples = 200
dataset_stream = load_dataset(
"data-is-better-together/open-image-preferences-v1",
split="cleaned",
streaming=True,
)
# Collect just n_samples prompts, filtering for non-empty "simplified_prompt"
streamed_prompts = []
for item in dataset_stream:
if item.get("simplified_prompt"):
streamed_prompts.append(item["simplified_prompt"])
if len(streamed_prompts) >= n_samples:
break
random.seed(42)
prompts = streamed_prompts # already exactly n_samples, but keep as list
print(f"Loaded {len(prompts)} prompts from streaming dataset")
Now, let’s generate the image pairs. Because we are fine-tuning for enhancement, both models share the same seed, so each pair shows roughly the same scene and differs mainly in quality.
[ ]:
def _fetch_image_bytes(output):
"""
Given output from replicate.run, fetches the actual image bytes.
- If output is a file-like object, calls .read()
- If output is a list of URLs, downloads the first URL
- If output is a single URL string, downloads it directly
"""
if hasattr(output, "read"):
return output.read()
if isinstance(output, list):
# Replicate typically returns a list of URLs (even length-1).
url = output[0]
elif isinstance(output, str):
url = output
else:
raise ValueError(f"Unexpected output type: {type(output)}")
resp = requests.get(url)
resp.raise_for_status()
return resp.content
def generate_image_pair(
prompt: str,
seed: int,
model_4b: str = "black-forest-labs/flux-2-klein-4b",
model_9b: str = "black-forest-labs/flux-2-klein-9b",
) -> tuple[bytes, bytes]:
"""Generate a pair of images using flux-2-klein-4b and flux-2-klein-9b.
Args:
prompt: Text prompt for image generation
seed: Random seed (same for both models)
model_4b: Replicate model identifier for 4B model
model_9b: Replicate model identifier for 9B model
Returns:
Tuple of (image_4b_bytes, image_9b_bytes)
"""
output_4b = replicate.run(model_4b, input={"prompt": prompt, "seed": seed})
output_9b = replicate.run(model_9b, input={"prompt": prompt, "seed": seed})
img_4b_bytes = _fetch_image_bytes(output_4b)
img_9b_bytes = _fetch_image_bytes(output_9b)
return img_4b_bytes, img_9b_bytes
image_pairs = []
for i, prompt in enumerate(tqdm(prompts, desc="Generating pairs")):
seed = i
try:
img_4b, img_9b = generate_image_pair(prompt, seed)
image_pairs.append((img_4b, img_9b))
except Exception as e:
print(f"Error generating pair {i}: {e}")
continue
print(f"Generated {len(image_pairs)} image pairs")
Let’s display a sample pair to verify the generation:
[ ]:
if image_pairs:
from IPython.display import display, HTML
import base64
def b64_img(img_bytes):
return f"<img src='data:image/png;base64,{base64.b64encode(img_bytes).decode()}' style='max-width: 300px; border:1px solid #ccc;'/>"
    n = min(5, len(image_pairs))
print(f"Displaying {n} sample pairs (4B model → 9B model):")
html = """
<table>
<tr>
<th style='text-align:center;'>Pair #</th>
<th style='text-align:center;'>4B model</th>
<th style='text-align:center;'>9B model</th>
</tr>
"""
for i in range(n):
img_4b, img_9b = image_pairs[i]
html += f"""
<tr>
<td style='text-align:center; font-weight:bold;'>{i+1}</td>
<td>{b64_img(img_4b)}</td>
<td>{b64_img(img_9b)}</td>
</tr>
"""
html += "</table>"
display(HTML(html))
Displaying 5 sample pairs (4B model → 9B model):
In general, the images generated by FLUX.2 [klein] 9B are more detailed than those from FLUX.2 [klein] 4B, so there should be some room for improvement.
Step 2: Prepare Dataset (P-Image-Edit LoRA)
Format the image pairs according to p-image-edit-trainer requirements: ROOT_start.EXT and ROOT_end.EXT pairs. We’ll also add optional caption files (ROOT.txt) with the trigger word sks_enhance to help the model learn the enhancement behavior. For P-Image LoRA (text-to-image), the dataset format is different — see the LoRA documentation.
[10]:
def create_dataset_zip(
image_pairs: list[tuple[bytes, bytes]], output_path: str = "dataset.zip"
) -> str:
"""Create a ZIP file with image pairs in the correct format.
Args:
image_pairs: List of (start_image_bytes, end_image_bytes) tuples
output_path: Path for the output ZIP file
Returns:
Path to the created ZIP file
"""
trigger_word = "sks_enhance"
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
for i, (img_start, img_end) in enumerate(image_pairs):
base_name = f"pair_{i:03d}"
zipf.writestr(f"{base_name}_start.png", img_start)
zipf.writestr(f"{base_name}_end.png", img_end)
zipf.writestr(f"{base_name}.txt", trigger_word.encode("utf-8"))
return output_path
dataset_zip_path = create_dataset_zip(image_pairs)
print(f"Created dataset ZIP: {dataset_zip_path}")
print(f"Dataset size: {os.path.getsize(dataset_zip_path) / 1024 / 1024:.2f} MB")
Created dataset ZIP: dataset.zip
Dataset size: 2.54 MB
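As a quick sanity check, you can list the first few entries of the archive to confirm the ROOT_start.EXT / ROOT_end.EXT / ROOT.txt naming:
[ ]:
# Sanity check: print a few archive entries to verify the naming scheme.
with zipfile.ZipFile(dataset_zip_path) as zf:
    for name in sorted(zf.namelist())[:6]:
        print(name)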
Step 3: Upload Dataset to HuggingFace
Upload the dataset ZIP to HuggingFace. You’ll need to create a repository first (e.g., your-username/lora-dataset).
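If the dataset repository does not exist yet, you can create it with the same HfApi client; the repo_id below is a placeholder:
[ ]:
# Create the dataset repository if needed (repo_id is a placeholder;
# exist_ok=True avoids an error when the repository already exists).
hf_api.create_repo(
    repo_id="your-username/lora-dataset",
    repo_type="dataset",
    exist_ok=True,
)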
[ ]:
hf_dataset_repo = "davidberenstein1957/enhance"
upload_file(
path_or_fileobj=dataset_zip_path,
path_in_repo="input.zip",
repo_id=hf_dataset_repo,
repo_type="dataset",
token=hf_token,
)
hf_dataset_url = (
f"https://huggingface.co/datasets/{hf_dataset_repo}/resolve/main/input.zip"
)
print(f"Dataset uploaded to: {hf_dataset_url}")
Step 4: P-Image-Edit LoRA Training
Start P-Image-Edit LoRA training with p-image-edit-trainer using custom hyperparameters: epochs=1, learning_rate=0.00001 and steps=2000. For a full overview of hyperparameters, see the LoRA documentation.
[ ]:
training_input = {
"image_data": hf_dataset_url,
"learning_rate": 0.00001,
"steps": 2000,
"epochs": 1,
}
print("Starting training...")
prediction = replicate.predictions.create(
model="prunaai/p-image-edit-trainer:8b6274e55245fe89fb22d63cd7bb1cbbf58222716bfac9750132cdcd2491d3ac",
input=training_input,
)
print(f"Training started. Prediction ID: {prediction.id}")
print(f"Monitor at: https://replicate.com/p/{prediction.id}")
print("Waiting for training to complete (this may take 30-60 minutes)...")
prediction.wait()
if prediction.status != "succeeded":
raise Exception(
f"Training {prediction.status}: {getattr(prediction, 'error', 'Unknown error')}"
)
print("Training completed!")
output = prediction.output
if hasattr(output, "url"):
lora_output_url = output.url
elif isinstance(output, str):
lora_output_url = output
else:
lora_output_url = str(output)
print(f"LoRA output URL: {lora_output_url}")
Step 5: Download and Extract LoRA Weights
Download the training output ZIP, extract the LoRA weights file, and upload it to HuggingFace.
[ ]:
response = requests.get(lora_output_url)
response.raise_for_status()
lora_zip_path = "lora_output.zip"
with open(lora_zip_path, "wb") as f:
f.write(response.content)
print(f"Downloaded LoRA output ZIP: {lora_zip_path}")
with zipfile.ZipFile(lora_zip_path, "r") as zipf:
file_list = zipf.namelist()
print(f"Files in ZIP: {file_list}")
lora_file = next((f for f in file_list if f.endswith(".safetensors")), None)
if not lora_file:
raise ValueError("No .safetensors file found in ZIP")
lora_weights_bytes = zipf.read(lora_file)
print(f"Extracted LoRA file: {lora_file}")
lora_weights_path = "weights.safetensors"
with open(lora_weights_path, "wb") as f:
f.write(lora_weights_bytes)
print(f"LoRA weights saved to: {lora_weights_path}")
Step 6: Upload LoRA Weights to HuggingFace
Upload the extracted LoRA weights file to HuggingFace.
[ ]:
hf_lora_repo = "davidberenstein1957/p-image-edit-photo-enhancement-lora"
upload_file(
path_or_fileobj=lora_weights_path,
path_in_repo="weights.safetensors",
repo_id=hf_lora_repo,
repo_type="model",
token=hf_token,
)
hf_lora_url = f"https://huggingface.co/{hf_lora_repo}/resolve/main/weights.safetensors"
print(f"LoRA weights uploaded to: {hf_lora_url}")
Step 7: P-Image-Edit LoRA Inference
Use p-image-edit-lora to test the trained LoRA adapter on a new image (P-Image-Edit LoRA inference).
[17]:
test_prompt = (
"A serene mountain landscape at sunset with snow-capped peaks, "
"enhanced with dramatic lighting and vibrant colors"
)
test_seed = 42
test_image_output = replicate.run(
"prunaai/p-image",
input={"prompt": test_prompt, "seed": test_seed},
)
test_image_bytes = test_image_output.read()
print("Test input image:")
display(Image(data=test_image_bytes))
edit_prompt = "sks_enhance"
edited_output = replicate.run(
"prunaai/p-image-edit-lora",
input={
"prompt": edit_prompt,
"images": [io.BytesIO(test_image_bytes)],
"lora_scale": 1.0,
"hf_token": hf_token,
},
)
edited_image_bytes = edited_output.read()
print("Edited image with LoRA:")
display(Image(data=edited_image_bytes))
Test input image:
Edited image with LoRA:
Example: Stylized generation with Flux Klein 9B (comic noir)
This section uses Flux Klein 9B for stylized text-to-image generation without a LoRA: we steer the model with a style-specific prompt (comic noir: high contrast, shadows, vintage comic aesthetic). To get a reusable comic noir style you can trigger with a single token, train a P-Image LoRA (text-to-image) with p-image-trainer on a dataset of comic noir images; see the LoRA documentation.
[ ]:
FLUX_KLEIN_9B = "black-forest-labs/flux-2-klein-9b"
comic_noir_prompt = (
"A detective in a trench coat standing in a rainy alley at night, "
"comic noir style, high contrast black and white with bold shadows, "
"halftone dots, vintage pulp magazine aesthetic, dramatic lighting"
)
comic_noir_seed = 123
output = replicate.run(
FLUX_KLEIN_9B, input={"prompt": comic_noir_prompt, "seed": comic_noir_seed}
)
comic_noir_bytes = _fetch_image_bytes(output)
print("Flux Klein 9B — comic noir style:")
display(Image(data=comic_noir_bytes))
Summary
You’ve successfully completed the end-to-end P-Image-Edit LoRA training and LoRA inference workflow:
✅ Generated 200 training pairs using Flux-2-Klein models
✅ Prepared dataset in the correct format (_start/_end pairs) with trigger word sks_enhance
✅ Uploaded dataset to HuggingFace
✅ Trained a LoRA adapter with learning_rate=0.00001 and steps=2000
✅ Downloaded and extracted LoRA weights from the training output
✅ Uploaded LoRA weights to HuggingFace
✅ Tested LoRA inference with p-image-edit-lora using the trigger word
Note: A similar workflow exists for text-to-image using P-Image LoRA training and P-Image LoRA inference (p-image-trainer and p-image-lora):
The dataset format is simpler (just images with caption files); see the sketch below
Each caption should include a trigger word (e.g., sks_enhance); you can also add an exact prompt describing what you want the model to learn, but this is not required for general image editing tasks
Example for text rendering: if training a LoRA to improve text quality, your captions might be sks_enhance, "Hello World" rendered in bold sans-serif font or sks_enhance, "Welcome" in elegant script with decorative flourishes
The trigger word helps the model recognize when to apply the enhancement, while the exact prompt teaches it what enhancement to perform
See the :doc:`/docs_pruna_endpoints/performance_models/lora` documentation for details
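As a rough illustration of that simpler text-to-image format, the sketch below zips images together with matching caption files. The exact naming requirements for p-image-trainer are documented in the LoRA docs, so treat the file names here as assumptions.
[ ]:
# Hypothetical sketch of a text-to-image LoRA dataset: one ROOT.png per image
# plus a ROOT.txt caption that starts with the trigger word. The file naming is
# an assumption; check the LoRA documentation for the exact requirements.
import zipfile


def create_t2i_dataset_zip(
    images: list[bytes], captions: list[str], output_path: str = "t2i_dataset.zip"
) -> str:
    with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
        for i, (img, caption) in enumerate(zip(images, captions)):
            base_name = f"image_{i:03d}"
            zipf.writestr(f"{base_name}.png", img)
            zipf.writestr(f"{base_name}.txt", caption.encode("utf-8"))
    return output_path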