
Commit 7999503

faraaz-bot authored and committed
Adding reinforcement learning blog
1 parent febcc6b commit 7999503

14 files changed: 1,619 additions and 0 deletions

.authorlist.txt

Lines changed: 2 additions & 0 deletions
@@ -320,6 +320,8 @@ Pauli
 Pihajoki
 Pei
 Zhang
+Peng
+Sun
 Phani
 Vaddadi
 Philipp

blogs/artificial-intelligence/wan-flow-grpo/README.md

Lines changed: 525 additions & 0 deletions
Large diffs are not rendered by default.
Three image files added (644 KB, 60.6 KB, and 2.12 MB); not rendered in this view.
Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
import torch
import numpy as np
from diffusers import WanPipeline
from diffusers.utils import export_to_video
from peft import PeftModel
from accelerate import Accelerator

# Initialize accelerator
accelerator = Accelerator(
    mixed_precision="bf16",  # can be "no", "fp16", or "bf16"
    device_placement=True
)

# --- Paths ---
pretrained_model = '/workspace/hf_cache/Wan2.1-T2V-14B-Diffusers'
lora_path = '/workspace/logs/video_ocr/wan_flow_grpo_14B/checkpoints/checkpoint-176/lora'

# --- Load pipeline ---
print("Loading pipeline...")
pipeline = WanPipeline.from_pretrained(
    pretrained_model,
    torch_dtype=torch.bfloat16,
)

# Disable gradient computations for inference
pipeline.vae.requires_grad_(False)
pipeline.text_encoder.requires_grad_(False)
pipeline.transformer.requires_grad_(False)

pipeline.safety_checker = None

# Load LoRA fine-tuned weights (optional)
print("Loading LoRA adapter...")
pipeline.transformer = PeftModel.from_pretrained(pipeline.transformer, lora_path)
pipeline.transformer.set_adapter('default')

# --- Prepare model for distributed inference ---
pipeline.vae.to(accelerator.device, dtype=torch.float32)
pipeline.text_encoder.to(accelerator.device, dtype=torch.bfloat16)
pipeline.transformer.to(accelerator.device)

# Accelerator handles device wrapping automatically
pipeline = accelerator.prepare(pipeline)

# --- Prompts ---
prompt = (
    "A vast shot opens in the dark expanse of space, scattered with distant stars and a faint red hue from the "
    "planet Mars. The camera slowly glides past a futuristic spaceship, its sleek metallic hull reflecting the "
    "starlight in soft gradients. Subtle thruster lights pulse along its surface as the camera tilts to reveal "
    "bold markings etched across the side: “Mars Colony One.” The letters gleam under the distant cosmic glow. "
    "The camera pulls back gradually, capturing the ship’s immense scale as it drifts silently through the "
    "endless void."
)

negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, "
    "static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, "
    "poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, "
    "messy background, three legs, many people in the background, walking backwards"
)

# --- Inference ---
if accelerator.is_main_process:
    print("Starting distributed inference...")

output = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    num_frames=121,
    num_inference_steps=50,
    guidance_scale=5.0
).frames[0]

# Save video only on main process to avoid race conditions
if accelerator.is_main_process:
    export_to_video(output, 'output_wan21_14b.mp4', fps=24)

accelerator.wait_for_everyone()
print("Inference complete.")
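Since the script builds an Accelerator, it can presumably be launched across multiple GPUs with Accelerate's CLI, e.g. accelerate launch --num_processes 8 inference.py, where the filename inference.py and the process count are placeholders (the actual script name is not shown in this diff). A plain python inference.py invocation should also work for a single-GPU run, since Accelerator defaults to one process.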
