
Commit 39047da

Add multi-card inference support for Wan pipelines (#2325)
Signed-off-by: Daniel Socek <daniel.socek@intel.com>
1 parent bfbf0d2 commit 39047da

File tree

7 files changed: +331 −83 lines changed

README.md

Lines changed: 2 additions & 2 deletions

```diff
@@ -300,13 +300,13 @@ The following model architectures, tasks and device distributions have been vali
 | Stable Diffusion | :heavy_check_mark: | :heavy_check_mark: | <ul><li>[text-to-image generation](/examples/stable-diffusion#text-to-image-generation)</li><li>[image-to-image generation](/examples/stable-diffusion#image-to-image-generation)</li></ul> |
 | Stable Diffusion XL | :heavy_check_mark: | :heavy_check_mark: | <ul><li>[text-to-image generation](/examples/stable-diffusion#stable-diffusion-xl-sdxl)</li><li>[image-to-image generation](/examples/stable-diffusion#stable-diffusion-xl-refiner)</li></ul> |
 | Stable Diffusion Depth2img | | <ul><li>Single card</li></ul> | <ul><li>[depth-to-image generation](/examples/stable-diffusion)</li></ul> |
-| Stable Diffusion 3 | :heavy_check_mark: | <ul><li>Single card</li></ul> | <ul><li>[text-to-image generation](/examples/stable-diffusion#stable-diffusion-3-and-35-sd3)</li></ul> |
+| Stable Diffusion 3 | :heavy_check_mark: | :heavy_check_mark: | <ul><li>[text-to-image generation](/examples/stable-diffusion#stable-diffusion-3-and-35-sd3)</li></ul> |
 | LDM3D | | <ul><li>Single card</li></ul> | <ul><li>[text-to-image generation](/examples/stable-diffusion#text-to-image-generation)</li></ul> |
 | FLUX.1 | <ul><li>LoRA</li></ul> | <ul><li>Single card</li></ul> | <ul><li>[text-to-image generation](/examples/stable-diffusion#flux1)</li><li>[image-to-image generation](/examples/stable-diffusion#flux1-image-to-image)</li></ul> |
 | Text to Video | | <ul><li>Single card</li></ul> | <ul><li>[text-to-video generation](/examples/stable-diffusion#text-to-video-generation)</li></ul> |
 | Image to Video | | <ul><li>Single card</li></ul> | <ul><li>[image-to-video generation](/examples/stable-diffusion#image-to-video-generation)</li></ul> |
 | i2vgen-xl | | <ul><li>Single card</li></ul> | <ul><li>[image-to-video generation](/examples/stable-diffusion#I2vgen-xl)</li></ul> |
-| Wan | | <ul><li>Single card</li></ul> | <ul><li>[text-to-video generation](/examples/stable-diffusion#text-to-video-with-wan-22)</li><li>[image-to-video generation](/examples/stable-diffusion#image-to-video-with-wan-22)</li></ul> |
+| Wan | | :heavy_check_mark: | <ul><li>[text-to-video generation](/examples/stable-diffusion#text-to-video-with-wan-22)</li><li>[image-to-video generation](/examples/stable-diffusion#image-to-video-with-wan-22)</li></ul> |
 
 ### PyTorch Image Models/TIMM:
```

docs/source/index.mdx

Lines changed: 5 additions & 5 deletions

```diff
@@ -122,16 +122,16 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
 
 | Architecture | Training. | Inference | Tasks |
 |----------------------------|:----------------------:|:-----------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
-| Stable Diffusion | | | <ul><li>[text-to-image generation](/examples/stable-diffusion)</li></ul> |
-| Stable Diffusion XL | | | <ul><li>[text-to-image generation](/examples/stable-diffusion)</li></ul> |
+| Stable Diffusion ||| <ul><li>[text-to-image generation](/examples/stable-diffusion)</li></ul> |
+| Stable Diffusion XL ||| <ul><li>[text-to-image generation](/examples/stable-diffusion)</li></ul> |
 | Stable Diffusion Depth2img | | <ul><li>Single card</li></ul> | <ul><li>[depth-to-image generation](/examples/stable-diffusion)</li></ul> |
-| Stable Diffusion 3 | | <ul><li>Single card</li></ul> | <ul><li>[text-to-image generation](/examples/stable-diffusion#stable-diffusion-3-and-35-sd3)</li></ul> |
+| Stable Diffusion 3 || <ul><li>Single card</li></ul> | <ul><li>[text-to-image generation](/examples/stable-diffusion#stable-diffusion-3-and-35-sd3)</li></ul> |
 | LDM3D | | <ul><li>Single card</li></ul> | <ul><li>[text-to-image generation](/examples/stable-diffusion)</li></ul> |
-| FLUX.1 | <ul><li>LoRA</li></ul> | <ul><li>Single card</li></ul> | <ul><li>[text-to-image generation](/examples/stable-diffusion)</li></ul> |
+| FLUX.1 | <ul><li>LoRA</li></ul> | | <ul><li>[text-to-image generation](/examples/stable-diffusion)</li></ul> |
 | Text to Video | | <ul><li>Single card</li></ul> | <ul><li>[text-to-video generation](/examples/stable-diffusion#text-to-video-generation)</li></ul> |
 | Image to Video | | <ul><li>Single card</li></ul> | <ul><li>[image-to-video generation](/examples/stable-diffusion#image-to-video-generation)</li></ul> |
 | i2vgen-xl | | <ul><li>Single card</li></ul> | <ul><li>[image-to-video generation](/examples/stable-diffusion#I2vgen-xl)</li></ul> |
-| Wan | | <ul><li>Single card</li></ul> | <ul><li>[text-to-video generation](/examples/stable-diffusion#text-to-video-with-wan-22)</li><li>[image-to-video generation](/examples/stable-diffusion#image-to-video-with-wan-22)</li></ul> |
+| Wan | | | <ul><li>[text-to-video generation](/examples/stable-diffusion#text-to-video-with-wan-22)</li><li>[image-to-video generation](/examples/stable-diffusion#image-to-video-with-wan-22)</li></ul> |
 
 - PyTorch Image Models/TIMM:
```

examples/stable-diffusion/README.md

Lines changed: 62 additions & 1 deletion

````diff
@@ -476,9 +476,39 @@ python image_to_video_generation.py \
     --fps 24 \
     --num_frames 121 \
     --sdp_on_bf16 \
-    --bf16
+    --bf16
 ```
 
+#### Distributed Image-to-Video Wan 2.2 Inference
+
+Wan models use classifier-free guidance (CFG), which processes both conditional and unconditional latents during denoising.
+With the `--use_distributed_cfg` option, we parallelize these two steps across a pair of HPU devices and then synchronize to apply guidance.
+While this mode uses 2 HPUs per unique generated video, it achieves almost 2x faster inference.
+
+Here is an example of running the Wan2.2 image-to-video model on 2 HPU devices in distributed CFG mode:
+
+```bash
+PT_HPU_LAZY_MODE=1 \
+python ../gaudi_spawn.py --world_size 2 image_to_video_generation.py \
+    --model_name_or_path "Wan-AI/Wan2.2-TI2V-5B-Diffusers" \
+    --image_path "https://raw.githubusercontent.com/Wan-Video/Wan2.2/main/examples/i2v_input.JPG" \
+    --video_save_dir ./wan2.2-output \
+    --prompts "The cat removes the glasses from its eyes." \
+    --use_habana \
+    --use_hpu_graphs \
+    --use_distributed_cfg \
+    --height 1088 \
+    --width 800 \
+    --fps 24 \
+    --num_frames 121 \
+    --sdp_on_bf16 \
+    --bf16
+```
+
+> [!NOTE]
+> Distributed CFG mode requires an even number of devices in the `world_size`.
+
 ### Text-to-Video with Wan 2.2
 Wan2.2 is a comprehensive and open suite of video foundation models. Please refer to [Huggingface Wan2.2 doc](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B)
 
@@ -502,6 +532,37 @@ python text_to_video_generation.py \
     --dtype bf16
 ```
 
+#### Distributed Text-to-Video Wan 2.2 Inference
+
+Wan models use classifier-free guidance (CFG), which processes both conditional and unconditional latents during denoising.
+With the `--use_distributed_cfg` option, we parallelize these two steps across a pair of HPU devices and then synchronize to apply guidance.
+While this mode uses 2 HPUs per unique generated video, it achieves almost 2x faster inference.
+
+Here is an example of running the Wan2.2 text-to-video model on 2 HPU devices in distributed CFG mode:
+
+```bash
+PT_HPU_LAZY_MODE=1 \
+python ../gaudi_spawn.py --world_size 2 text_to_video_generation.py \
+    --model_name_or_path "Wan-AI/Wan2.2-TI2V-5B-Diffusers" \
+    --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
+    --pipeline_type wan \
+    --num_videos_per_prompt 1 \
+    --use_habana \
+    --use_hpu_graphs \
+    --use_distributed_cfg \
+    --height 704 \
+    --width 1280 \
+    --num_frames 121 \
+    --num_inference_steps 50 \
+    --guidance_scale 5.0 \
+    --output_type mp4 \
+    --dtype bf16
+```
+
+> [!NOTE]
+> Distributed CFG mode requires an even number of devices in the `world_size`.
+
 ### Text-to-Video with CogvideoX
 
 CogVideoX is an open-source version of the video generation model originating from QingYing, unveiled in https://huggingface.co/THUDM/CogVideoX-5b.
````
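The distributed CFG scheme described above can be illustrated with a minimal sketch in plain Python (no HPUs). Here `denoise_branch`, `apply_cfg`, and `GUIDANCE_SCALE` are invented stand-ins, and which rank of a pair takes which branch is an assumption; the actual pipeline runs the two branches on separate HPU devices and exchanges results with a collective op before applying guidance.

```python
GUIDANCE_SCALE = 5.0  # guidance scale the Wan examples above pass


def denoise_branch(latent: float, conditional: bool) -> float:
    # Toy noise prediction; the real pipeline runs a full denoiser forward pass here.
    return latent * (0.9 if conditional else 0.5)


def apply_cfg(uncond: float, cond: float, scale: float) -> float:
    # Classifier-free guidance: push the prediction toward the conditional branch.
    return uncond + scale * (cond - uncond)


latent = 1.0

# Single-card mode: one device runs both branches one after the other.
sequential = apply_cfg(
    denoise_branch(latent, conditional=False),
    denoise_branch(latent, conditional=True),
    GUIDANCE_SCALE,
)

# Distributed CFG mode: each rank of a pair runs one branch in parallel,
# then the pair synchronizes and combines the two halves.
branch = {rank: denoise_branch(latent, conditional=(rank % 2 == 0)) for rank in (0, 1)}
parallel = apply_cfg(branch[1], branch[0], GUIDANCE_SCALE)

assert parallel == sequential  # same guided output either way
```

Because the two branches are independent until the combine step, splitting them across a device pair roughly halves the per-step denoising time, which is where the "almost 2x faster" figure comes from.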

examples/stable-diffusion/image_to_video_generation.py

Lines changed: 37 additions & 19 deletions

```diff
@@ -15,6 +15,7 @@
 
 import argparse
 import logging
+import os
 import sys
 from pathlib import Path
 
@@ -206,6 +207,11 @@ def main():
         help="Allow pyTorch to use reduced precision in the SDPA math backend",
     )
     parser.add_argument("--num_frames", type=int, default=25, help="The number of video frames to generate.")
+    parser.add_argument(
+        "--use_distributed_cfg",
+        action="store_true",
+        help="Use distributed CFG (classifier-free guidance) across 2 devices for Wan pipeline. Requires even world size.",
+    )
     parser.add_argument(
         "--profiling_warmup_steps",
         default=0,
@@ -291,7 +297,13 @@ def main():
         "sdp_on_bf16": args.sdp_on_bf16,
     }
 
-    set_seed(args.seed)
+    # Set RNG seed
+    seed_dist_offset = int(os.getenv("RANK", "0"))
+    if args.use_distributed_cfg:
+        # The same seed is needed for each pair of workers in distributed CFG mode
+        seed_dist_offset = seed_dist_offset // 2
+    set_seed(args.seed + seed_dist_offset)
+
     if args.bf16:
         kwargs["torch_dtype"] = torch.bfloat16
 
@@ -377,6 +389,7 @@ def main():
             num_frames=args.num_frames,
             num_inference_steps=args.num_inference_steps,
             guidance_scale=5.0,  # WAN I2V recommended guidance scale
+            use_distributed_cfg=args.use_distributed_cfg,
             output_type=args.output_type,
             profiling_warmup_steps=args.profiling_warmup_steps,
             profiling_steps=args.profiling_steps,
@@ -419,25 +432,30 @@ def main():
     if args.output_type == "pil":
         video_save_dir = Path(args.video_save_dir)
         video_save_dir.mkdir(parents=True, exist_ok=True)
-        logger.info(f"Saving video frames in {video_save_dir.resolve()}...")
-        for i, frames in enumerate(outputs.frames):
-            if args.gif:
-                export_to_gif(frames, args.video_save_dir + "/gen_video_" + str(i).zfill(2) + ".gif")
-            else:
-                export_to_video(
-                    frames, args.video_save_dir + "/gen_video_" + str(i).zfill(2) + ".mp4", fps=args.fps
-                )
-
-            if args.save_frames_as_images:
-                for j, frame in enumerate(frames):
-                    frame.save(
-                        args.video_save_dir
-                        + "/gen_video_"
-                        + str(i).zfill(2)
-                        + "_frame_"
-                        + str(j).zfill(2)
-                        + ".png"
-                    )
+
+        rank = int(os.getenv("RANK", "0"))
+        world_size = int(os.getenv("WORLD_SIZE", "1"))
+        rank_ext = f"_rank{rank}" if world_size > 1 else ""
+        skip_rank = False
+        if args.use_distributed_cfg and world_size > 1:
+            rank_ext += f"and{rank + 1}"
+            skip_rank = rank % 2 == 1
+
+        if not skip_rank:
+            logger.info(f"Saving video frames in {video_save_dir.resolve()}...")
+            for i, frames in enumerate(outputs.frames):
+                if args.gif:
+                    export_to_gif(frames, f"{args.video_save_dir}/gen_video_{str(i).zfill(2)}{rank_ext}.gif")
+                else:
+                    export_to_video(
+                        frames, f"{args.video_save_dir}/gen_video_{str(i).zfill(2)}{rank_ext}.mp4", fps=args.fps
+                    )
+
+                if args.save_frames_as_images:
+                    for j, frame in enumerate(frames):
+                        frame.save(
+                            f"{args.video_save_dir}/gen_video_{str(i).zfill(2)}_frame_{str(j).zfill(2)}{rank_ext}.png"
+                        )
     else:
         logger.warning("--output_type should be equal to 'pil' to save frames in --video_save_dir.")
```

examples/stable-diffusion/text_to_video_generation.py

Lines changed: 27 additions & 6 deletions

```diff
@@ -17,6 +17,7 @@
 
 import argparse
 import logging
+import os
 import sys
 from pathlib import Path
 
@@ -149,6 +150,11 @@ def main():
         choices=["bf16", "fp32", "autocast_bf16"],
         help="Which runtime dtype to perform generation in.",
     )
+    parser.add_argument(
+        "--use_distributed_cfg",
+        action="store_true",
+        help="Use distributed CFG (classifier-free guidance) across 2 devices for Wan pipeline. Requires even world size.",
+    )
     args = parser.parse_args()
     # Setup logging
     logging.basicConfig(
@@ -183,6 +189,13 @@ def main():
     elif args.dtype == "fp32":
         kwargs["torch_dtype"] = torch.float32
 
+    # Set RNG seed
+    seed_dist_offset = int(os.getenv("RANK", "0"))
+    if args.use_distributed_cfg:
+        # The same seed is needed for each pair of workers in distributed CFG mode
+        seed_dist_offset = seed_dist_offset // 2
+    set_seed(args.seed + seed_dist_offset)
+
     # Generate images
     if args.pipeline_type == "stable_diffusion":
         pipeline: GaudiTextToVideoSDPipeline = GaudiTextToVideoSDPipeline.from_pretrained(
@@ -199,7 +212,6 @@ def main():
         return None
 
     if args.pipeline_type == "stable_diffusion":
-        set_seed(args.seed)
         outputs = pipeline(
             prompt=args.prompts,
             num_videos_per_prompt=args.num_videos_per_prompt,
@@ -242,13 +254,13 @@ def main():
         filename = video_save_dir / "cogvideoX_out.mp4"
         export_to_video(video, str(filename.resolve()), fps=8)
     elif args.pipeline_type == "wan":
-        set_seed(args.seed)
         outputs = pipeline(
             prompt=args.prompts,
             num_videos_per_prompt=args.num_videos_per_prompt,
             num_inference_steps=args.num_inference_steps,
             guidance_scale=args.guidance_scale,
             negative_prompt=args.negative_prompts,
+            use_distributed_cfg=args.use_distributed_cfg,
             output_type="np" if args.output_type == "mp4" else args.output_type,
             **kwargs_call,
         )
@@ -262,11 +274,20 @@ def main():
     if args.output_type == "mp4":
         video_save_dir = Path(args.video_save_dir)
         video_save_dir.mkdir(parents=True, exist_ok=True)
-        logger.info(f"Saving videos in {video_save_dir.resolve()}...")
 
-        for i, video in enumerate(outputs.frames):
-            filename = video_save_dir / f"wan_video_{i + 1}.mp4"
-            export_to_video(video, str(filename.resolve()), fps=16)
+        rank = int(os.getenv("RANK", "0"))
+        world_size = int(os.getenv("WORLD_SIZE", "1"))
+        rank_ext = f"_rank{rank}" if world_size > 1 else ""
+        skip_rank = False
+        if args.use_distributed_cfg and world_size > 1:
+            rank_ext += f"and{rank + 1}"
+            skip_rank = rank % 2 == 1
+
+        if not skip_rank:
+            logger.info(f"Saving videos in {video_save_dir.resolve()}...")
+            for i, video in enumerate(outputs.frames):
+                filename = video_save_dir / f"wan_video_{i + 1}{rank_ext}.mp4"
+                export_to_video(video, str(filename.resolve()), fps=16)
     else:
         logger.warning("--output_type should be equal to 'mp4' to save videos in --video_save_dir.")
```
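Both scripts dedupe output files the same way: in a plain multi-card run every rank saves its own video with a `_rank{r}` suffix, while in distributed CFG mode both ranks of a pair hold the same guided video, so the even rank writes it (tagged `_rank{r}and{r+1}`) and the odd rank skips saving. A standalone sketch of that naming logic (`save_plan` is a hypothetical helper extracted from the diff, not a function in the scripts):

```python
def save_plan(rank: int, world_size: int, use_distributed_cfg: bool) -> tuple[str, bool]:
    # Returns (filename suffix, whether this rank skips saving),
    # mirroring the rank_ext / skip_rank logic in the diff above.
    rank_ext = f"_rank{rank}" if world_size > 1 else ""
    skip_rank = False
    if use_distributed_cfg and world_size > 1:
        rank_ext += f"and{rank + 1}"
        skip_rank = rank % 2 == 1
    return rank_ext, skip_rank


assert save_plan(0, 1, False) == ("", False)           # single card: plain filename
assert save_plan(1, 4, False) == ("_rank1", False)     # multi-card: every rank saves its own video
assert save_plan(0, 2, True) == ("_rank0and1", False)  # CFG pair: even rank saves for the pair
assert save_plan(1, 2, True) == ("_rank1and2", True)   # CFG pair: odd rank skips (duplicate video)
```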
