Skip to content

Commit e68ca4a

Browse files
authored
[CI]test: add wan22 i2v video similarity e2e (vllm-project#2262)
Signed-off-by: David Chen <530634352@qq.com>
1 parent 0cbd1cb commit e68ca4a

6 files changed

Lines changed: 927 additions & 0 deletions

File tree

.buildkite/test-nightly.yml

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,48 @@ steps:
355355
path: /mnt/hf-cache
356356
type: DirectoryOrCreate
357357

358+
- label: ":full_moon: Wan22 I2V Accuracy Test with H100"
359+
key: nightly-wan22-i2v-accuracy
360+
timeout_in_minutes: 180
361+
depends_on: upload-nightly-pipeline
362+
if: build.env("NIGHTLY") == "1" || build.pull_request.labels includes "nightly-test"
363+
commands:
364+
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
365+
- pytest -s -v tests/e2e/accuracy/wan22_i2v/test_wan22_i2v_video_similarity.py --run-level advanced_model
366+
agents:
367+
queue: "mithril-h100-pool"
368+
plugins:
369+
- kubernetes:
370+
podSpec:
371+
containers:
372+
- image: 936637512419.dkr.ecr.us-west-2.amazonaws.com/vllm-ci-pull-through-cache/q9t5s3a7/vllm-ci-test-repo:c392ce21e9cf9ea65c52b866447793db10e0261c
373+
resources:
374+
limits:
375+
nvidia.com/gpu: 2
376+
volumeMounts:
377+
- name: devshm
378+
mountPath: /dev/shm
379+
- name: hf-cache
380+
mountPath: /root/.cache/huggingface
381+
env:
382+
- name: HF_HOME
383+
value: /root/.cache/huggingface
384+
- name: HF_TOKEN
385+
valueFrom:
386+
secretKeyRef:
387+
name: hf-token-secret
388+
key: token
389+
nodeSelector:
390+
node.kubernetes.io/instance-type: gpu-h100-sxm
391+
volumes:
392+
- name: devshm
393+
emptyDir:
394+
medium: Memory
395+
- name: hf-cache
396+
hostPath:
397+
path: /mnt/hf-cache
398+
type: DirectoryOrCreate
399+
358400
- label: ":full_moon: Qwen-Image Diffusion Perf Test with H100"
359401
key: nightly-qwen-image-performance
360402
timeout_in_minutes: 300

tests/e2e/accuracy/conftest.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ def pytest_addoption(parser):
4747
help="Balanced sample count per GEdit task group",
4848
)
4949
group.addoption("--accuracy-workers", action="store", type=int, default=1, help="Worker count for accuracy benches")
50+
group.addoption(
51+
"--wan22-i2v-image-source",
52+
action="store",
53+
default=None,
54+
help="Image source for Wan2.2 I2V accuracy tests. Can be local path or remote URL.",
55+
)
56+
group.addoption(
57+
"--wan22-i2v-online-timeout-seconds",
58+
action="store",
59+
type=int,
60+
default=1200,
61+
help="Online serving timeout in seconds for Wan2.2 I2V accuracy tests.",
62+
)
5063

5164

5265
def _hf_cache_root() -> Path:
@@ -142,6 +155,17 @@ def accuracy_workers(request: pytest.FixtureRequest) -> int:
142155
return int(request.config.getoption("accuracy_workers"))
143156

144157

158+
@pytest.fixture(scope="session")
159+
def wan22_i2v_image_source(request: pytest.FixtureRequest) -> str | None:
160+
value = request.config.getoption("wan22_i2v_image_source")
161+
return str(value) if value else None
162+
163+
164+
@pytest.fixture(scope="session")
165+
def wan22_i2v_online_timeout_seconds(request: pytest.FixtureRequest) -> int:
166+
return int(request.config.getoption("wan22_i2v_online_timeout_seconds"))
167+
168+
145169
@pytest.fixture(scope="session")
def gebench_samples_per_type(request: pytest.FixtureRequest) -> int:
    """Expose the ``gebench_samples_per_type`` option as an integer."""
    raw_count = request.config.getoption("gebench_samples_per_type")
    return int(raw_count)

tests/e2e/accuracy/wan22_i2v/__init__.py

Whitespace-only changes.
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
from __future__ import annotations
2+
3+
import argparse
4+
import base64
5+
import json
6+
from io import BytesIO
7+
from pathlib import Path
8+
9+
import requests
10+
import torch
11+
from diffusers import UniPCMultistepScheduler, WanImageToVideoPipeline
12+
from diffusers.pipelines.wan import pipeline_wan_i2v as wan_i2v_module
13+
from diffusers.utils import export_to_video, load_image
14+
from PIL import Image
15+
16+
from tests.e2e.accuracy.wan22_i2v.wan22_i2v_video_similarity_common import BOUNDARY_RATIO
17+
18+
19+
def _parse_args() -> argparse.Namespace:
20+
parser = argparse.ArgumentParser(description="Run Wan2.2 I2V diffusers offline generation.")
21+
parser.add_argument("--model", required=True)
22+
parser.add_argument("--image-source", required=True)
23+
parser.add_argument("--prompt", required=True)
24+
parser.add_argument("--negative-prompt", required=True)
25+
parser.add_argument("--size", required=True)
26+
parser.add_argument("--fps", type=int, required=True)
27+
parser.add_argument("--num-frames", type=int, required=True)
28+
parser.add_argument("--guidance-scale", type=float, required=True)
29+
parser.add_argument("--guidance-scale-2", type=float, required=True)
30+
parser.add_argument("--flow-shift", type=float, required=True)
31+
parser.add_argument("--num-inference-steps", type=int, required=True)
32+
parser.add_argument("--seed", type=int, required=True)
33+
parser.add_argument("--output", required=True)
34+
parser.add_argument("--metadata-output", required=True)
35+
return parser.parse_args()
36+
37+
38+
def _parse_size(size: str) -> tuple[int, int]:
39+
width_str, height_str = size.lower().split("x", 1)
40+
return int(width_str), int(height_str)
41+
42+
43+
class _IdentityFtfy:
44+
@staticmethod
45+
def fix_text(text: str) -> str:
46+
return text
47+
48+
49+
def _ensure_wan_ftfy_fallback() -> None:
50+
if not hasattr(wan_i2v_module, "ftfy"):
51+
wan_i2v_module.ftfy = _IdentityFtfy()
52+
53+
54+
def _offline_cuda_device() -> torch.device:
55+
return torch.device("cuda:0")
56+
57+
58+
def _load_input_image(source: str) -> Image.Image:
59+
if source.startswith("data:image"):
60+
_, encoded = source.split(",", 1)
61+
image = Image.open(BytesIO(base64.b64decode(encoded)))
62+
image.load()
63+
return image.convert("RGB")
64+
65+
source_path = Path(source)
66+
if source_path.exists():
67+
image = Image.open(source_path)
68+
image.load()
69+
return image.convert("RGB")
70+
71+
image = load_image(source)
72+
if isinstance(image, Image.Image):
73+
image.load()
74+
return image.convert("RGB")
75+
76+
response = requests.get(source, timeout=60)
77+
response.raise_for_status()
78+
image = Image.open(BytesIO(response.content))
79+
image.load()
80+
return image.convert("RGB")
81+
82+
83+
def _resize_to_target(image: Image.Image, *, width: int, height: int) -> Image.Image:
    """Return ``image`` resized to ``width`` x ``height`` with LANCZOS resampling."""
    target_size = (width, height)
    return image.resize(target_size, resample=Image.Resampling.LANCZOS)
85+
86+
87+
def _configure_scheduler(pipe: WanImageToVideoPipeline, *, flow_shift: float) -> None:
    """Replace the pipeline's scheduler with a UniPC scheduler using ``flow_shift``.

    The new scheduler is built from the current scheduler's config, with
    ``flow_shift`` passed as an override.
    """
    current_config = pipe.scheduler.config
    pipe.scheduler = UniPCMultistepScheduler.from_config(current_config, flow_shift=flow_shift)
92+
93+
94+
def _write_metadata(
    path: Path,
    *,
    args: argparse.Namespace,
    width: int,
    height: int,
    frame_count: int,
) -> None:
    """Write the run parameters to ``path`` as indented UTF-8 JSON.

    Args:
        path: Destination file; missing parent directories are created.
        args: Parsed command-line options of the run.
        width: Parsed frame width.
        height: Parsed frame height.
        frame_count: Number of frames actually produced.
    """
    # Insertion order defines the key order in the emitted JSON.
    record: dict[str, object] = {}
    record["model"] = args.model
    record["image_source"] = args.image_source
    record["size"] = args.size
    record["width"] = width
    record["height"] = height
    record["fps"] = args.fps
    record["num_frames"] = args.num_frames
    record["actual_frame_count"] = frame_count
    record["guidance_scale"] = args.guidance_scale
    record["guidance_scale_2"] = args.guidance_scale_2
    record["boundary_ratio"] = BOUNDARY_RATIO
    record["flow_shift"] = args.flow_shift
    record["num_inference_steps"] = args.num_inference_steps
    record["seed"] = args.seed
    record["world_size"] = 1

    path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(record, ensure_ascii=False, indent=2)
    path.write_text(serialized, encoding="utf-8")
121+
122+
123+
def main() -> int:
    """Run one Wan2.2 I2V generation with diffusers and save the video plus metadata.

    Steps: parse the CLI options, select the CUDA device, load and configure
    the pipeline, prepare the input image, generate frames with a seeded
    generator, then export the video and write the metadata JSON.

    Returns:
        ``0`` once all outputs have been written.
    """
    cli_args = _parse_args()
    cuda_device = _offline_cuda_device()
    torch.cuda.set_device(cuda_device)
    _ensure_wan_ftfy_fallback()

    pipeline = WanImageToVideoPipeline.from_pretrained(cli_args.model, torch_dtype=torch.bfloat16)
    pipeline.register_to_config(boundary_ratio=BOUNDARY_RATIO)
    _configure_scheduler(pipeline, flow_shift=cli_args.flow_shift)
    pipeline.to(cuda_device)
    pipeline.set_progress_bar_config(disable=False)

    source_image = _load_input_image(cli_args.image_source)
    width, height = _parse_size(cli_args.size)
    conditioning_image = _resize_to_target(source_image, width=width, height=height)

    # Seed a generator on the device type so the run is driven by ``--seed``.
    rng = torch.Generator(device=cuda_device.type)
    rng.manual_seed(cli_args.seed)

    call_kwargs = {
        "image": conditioning_image,
        "prompt": cli_args.prompt,
        "negative_prompt": cli_args.negative_prompt,
        "height": height,
        "width": width,
        "num_frames": cli_args.num_frames,
        "guidance_scale": cli_args.guidance_scale,
        "guidance_scale_2": cli_args.guidance_scale_2,
        "num_inference_steps": cli_args.num_inference_steps,
        "generator": rng,
    }
    result = pipeline(**call_kwargs)
    frames = result.frames[0]

    video_path = Path(cli_args.output)
    metadata_path = Path(cli_args.metadata_output)
    video_path.parent.mkdir(parents=True, exist_ok=True)
    export_to_video(frames, str(video_path), fps=cli_args.fps)
    _write_metadata(
        metadata_path,
        args=cli_args,
        width=width,
        height=height,
        frame_count=len(frames),
    )

    # Only call the hook when the pipeline object provides it.
    if hasattr(pipeline, "maybe_free_model_hooks"):
        pipeline.maybe_free_model_hooks()
    return 0
162+
163+
164+
# Script entry point: exit the process with the status code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())

0 commit comments

Comments
 (0)