|
"""Generate a single video with a Wan2.1 T2V pipeline and a LoRA adapter.

The script loads the base pipeline in bfloat16, attaches a PEFT LoRA adapter
to the transformer, runs one text-to-video generation and writes the result
to ``output_wan21_14b.mp4``.

NOTE: when launched with several processes, every process runs the same
generation; only the main process writes the video file.
"""
import numpy as np
import torch
from accelerate import Accelerator
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video
from peft import PeftModel

# Initialize accelerator; it is used here for device selection and
# main-process coordination.
accelerator = Accelerator(
    mixed_precision="bf16",  # can be "no", "fp16", or "bf16"
    device_placement=True,
)


def _log_main(message):
    """Print ``message`` on the main process only to avoid duplicate output."""
    if accelerator.is_main_process:
        print(message)


# --- Paths ---
pretrained_model = '/workspace/hf_cache/Wan2.1-T2V-14B-Diffusers'
lora_path = '/workspace/logs/video_ocr/wan_flow_grpo_14B/checkpoints/checkpoint-176/lora'

# --- Load pipeline ---
_log_main("Loading pipeline...")
# Load the VAE directly in float32. Loading the whole pipeline in bfloat16 and
# upcasting the VAE afterwards cannot restore the precision dropped by the
# bfloat16 cast.
vae = AutoencoderKLWan.from_pretrained(
    pretrained_model,
    subfolder="vae",
    torch_dtype=torch.float32,
)
pipeline = WanPipeline.from_pretrained(
    pretrained_model,
    vae=vae,
    torch_dtype=torch.bfloat16,
)

# Disable gradient tracking on all model components; this script only runs
# inference.
pipeline.vae.requires_grad_(False)
pipeline.text_encoder.requires_grad_(False)
pipeline.transformer.requires_grad_(False)

# Load LoRA fine-tuned weights onto the transformer
_log_main("Loading LoRA adapter...")
pipeline.transformer = PeftModel.from_pretrained(pipeline.transformer, lora_path)
pipeline.transformer.set_adapter('default')

# --- Move model components to this process's device ---
# The pipeline is intentionally not passed through ``accelerator.prepare``:
# the components are placed on the device explicitly below, and the pipeline
# object itself is not something ``prepare`` wraps.
pipeline.vae.to(accelerator.device, dtype=torch.float32)
pipeline.text_encoder.to(accelerator.device, dtype=torch.bfloat16)
pipeline.transformer.to(accelerator.device)

# --- Prompts ---
prompt = (
    "A vast shot opens in the dark expanse of space, scattered with distant stars and a faint red hue from the "
    "planet Mars. The camera slowly glides past a futuristic spaceship, its sleek metallic hull reflecting the "
    "starlight in soft gradients. Subtle thruster lights pulse along its surface as the camera tilts to reveal "
    "bold markings etched across the side: “Mars Colony One.” The letters gleam under the distant cosmic glow. "
    "The camera pulls back gradually, capturing the ship’s immense scale as it drifts silently through the "
    "endless void."
)

negative_prompt = (
    "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, "
    "static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, "
    "poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, "
    "messy background, three legs, many people in the background, walking backwards"
)

# --- Inference ---
_log_main("Starting distributed inference...")

# ``.frames[0]`` selects the frames of the first (and only) generated video.
output = pipeline(
    prompt=prompt,
    negative_prompt=negative_prompt,
    height=480,
    width=832,
    num_frames=121,
    num_inference_steps=50,
    guidance_scale=5.0,
).frames[0]

# Save video only on main process to avoid race conditions
if accelerator.is_main_process:
    export_to_video(output, 'output_wan21_14b.mp4', fps=24)

# Keep all processes alive until the main process has finished writing.
accelerator.wait_for_everyone()
_log_main("Inference complete.")
0 commit comments