Commit 32da831

Add Wan2.2 support (#2231)

Signed-off-by: Jincheng Miao <jincheng.miao@intel.com>
Co-authored-by: Zhou, Huijuan <huijuan.zhou@intel.com>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>

1 parent 568f643 · commit 32da831

File tree

13 files changed: +2118 −3 lines changed

README.md

Lines changed: 1 addition & 0 deletions

@@ -306,6 +306,7 @@ The following model architectures, tasks and device distributions have been vali
 | Text to Video | | <li>Single card</li> | <li>[text-to-video generation](/examples/stable-diffusion#text-to-video-generation)</li> |
 | Image to Video | | <li>Single card</li> | <li>[image-to-video generation](/examples/stable-diffusion#image-to-video-generation)</li> |
 | i2vgen-xl | | <li>Single card</li> | <li>[image-to-video generation](/examples/stable-diffusion#I2vgen-xl)</li> |
+| Wan | | <li>Single card</li> | <li>[text-to-video generation](/examples/stable-diffusion#text-to-video-with-wan-22)</li><li>[image-to-video generation](/examples/stable-diffusion#image-to-video-with-wan-22)</li> |

 ### PyTorch Image Models/TIMM:

docs/source/index.mdx

Lines changed: 1 addition & 0 deletions

@@ -131,6 +131,7 @@ In the tables below, ✅ means single-card, multi-card and DeepSpeed have all be
 | Text to Video | | <li>Single card</li> | <li>[text-to-video generation](/examples/stable-diffusion#text-to-video-generation)</li> |
 | Image to Video | | <li>Single card</li> | <li>[image-to-video generation](/examples/stable-diffusion#image-to-video-generation)</li> |
 | i2vgen-xl | | <li>Single card</li> | <li>[image-to-video generation](/examples/stable-diffusion#I2vgen-xl)</li> |
+| Wan | | <li>Single card</li> | <li>[text-to-video generation](/examples/stable-diffusion#text-to-video-with-wan-22)</li><li>[image-to-video generation](/examples/stable-diffusion#image-to-video-with-wan-22)</li> |

 - PyTorch Image Models/TIMM:

examples/stable-diffusion/README.md

Lines changed: 45 additions & 0 deletions

@@ -457,6 +457,51 @@ python image_to_video_generation.py \
     --bf16
 ```

+### Image-to-Video with Wan 2.2
+Wan2.2 is a comprehensive and open suite of video foundation models. Please refer to the [Hugging Face Wan2.2 documentation](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B).
+
+Here is how to generate a video from one image and a text prompt:
+
+```bash
+PT_HPU_LAZY_MODE=1 \
+python image_to_video_generation.py \
+    --model_name_or_path "Wan-AI/Wan2.2-TI2V-5B-Diffusers" \
+    --image_path "https://raw.githubusercontent.com/Wan-Video/Wan2.2/main/examples/i2v_input.JPG" \
+    --video_save_dir ./wan2.2-output \
+    --prompts "The cat removes the glasses from its eyes." \
+    --use_habana \
+    --use_hpu_graphs \
+    --height 1088 \
+    --width 800 \
+    --fps 24 \
+    --num_frames 121 \
+    --sdp_on_bf16 \
+    --bf16
+```
+
+### Text-to-Video with Wan 2.2
+Wan2.2 is a comprehensive and open suite of video foundation models. Please refer to the [Hugging Face Wan2.2 documentation](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B).
+
+Here is how to generate a video from a text prompt:
+
+```bash
+PT_HPU_LAZY_MODE=1 \
+python text_to_video_generation.py \
+    --model_name_or_path "Wan-AI/Wan2.2-TI2V-5B-Diffusers" \
+    --prompts "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
+    --pipeline_type wan \
+    --num_videos_per_prompt 1 \
+    --use_habana \
+    --use_hpu_graphs \
+    --height 704 \
+    --width 1280 \
+    --num_frames 121 \
+    --num_inference_steps 50 \
+    --guidance_scale 5.0 \
+    --output_type mp4 \
+    --dtype bf16
+```
+
 ### Text-to-Video with CogvideoX

 CogVideoX is an open-source version of the video generation model originating from QingYing, unveiled in https://huggingface.co/THUDM/CogVideoX-5b.
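The Wan 2.2 image-to-video example pairs `--num_frames 121` with `--fps 24`; the resulting clip length follows directly from frames divided by playback rate. A quick sanity-check sketch (the helper name is mine, not part of the example scripts):

```python
def video_duration_seconds(num_frames: int, fps: int) -> float:
    """Clip duration implied by a frame count and a playback rate."""
    return num_frames / fps

# The Wan 2.2 i2v command above: 121 frames at 24 fps is a roughly 5-second clip
```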

examples/stable-diffusion/image_to_video_generation.py

Lines changed: 42 additions & 1 deletion

@@ -26,7 +26,9 @@
     GaudiEulerDiscreteScheduler,
     GaudiI2VGenXLPipeline,
     GaudiStableVideoDiffusionPipeline,
+    GaudiWanImageToVideoPipeline,
 )
+from optimum.habana.transformers.gaudi_configuration import GaudiConfig
 from optimum.habana.utils import set_seed

@@ -236,6 +238,18 @@ def main():
     is_i2v_model = any(model in args.model_name_or_path for model in i2v_models)
     cogvideo_models = ["cogvideo"]
     is_cogvideo_model = any(model in args.model_name_or_path.lower() for model in cogvideo_models)
+    wan_i2v_models = ["Wan2.2"]
+    is_wan_i2v_model = any(model in args.model_name_or_path for model in wan_i2v_models)
+
+    if is_wan_i2v_model:
+        gaudi_config_kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True}
+        if args.bf16:
+            gaudi_config_kwargs["use_torch_autocast"] = True
+
+        gaudi_config = GaudiConfig(**gaudi_config_kwargs)
+        args.gaudi_config_name = gaudi_config
+        logger.info(f"Gaudi Config: {gaudi_config}")
+
     # Load input image(s)
     input = []
     logger.info("Input image(s):")

@@ -245,6 +259,11 @@ def main():
         image = load_image(image_path)
         if is_i2v_model:
             image = image.convert("RGB")
+        elif is_wan_i2v_model:
+            image = image.resize((args.height, args.width))
+            # wan2.2 i2v pipeline only accepts 1 image
+            input = image
+            break
         else:
             image = image.resize((args.height, args.width))
             input.append(image)

@@ -342,6 +361,26 @@ def main():
             num_frames=args.num_frames,
             generator=generator,
         )
+    elif is_wan_i2v_model:
+        del kwargs["scheduler"]  # WAN I2V uses its own scheduler
+        pipeline = GaudiWanImageToVideoPipeline.from_pretrained(
+            args.model_name_or_path,
+            **kwargs,
+        )
+        outputs = pipeline(
+            image=input,
+            prompt=args.prompts,
+            negative_prompt=args.negative_prompts,
+            num_videos_per_prompt=args.num_videos_per_prompt,
+            height=args.height,
+            width=args.width,
+            num_frames=args.num_frames,
+            num_inference_steps=args.num_inference_steps,
+            guidance_scale=5.0,  # WAN I2V recommended guidance scale
+            output_type=args.output_type,
+            profiling_warmup_steps=args.profiling_warmup_steps,
+            profiling_steps=args.profiling_steps,
+        )
     else:
         pipeline = GaudiStableVideoDiffusionPipeline.from_pretrained(
             args.model_name_or_path,

@@ -385,7 +424,9 @@ def main():
         if args.gif:
             export_to_gif(frames, args.video_save_dir + "/gen_video_" + str(i).zfill(2) + ".gif")
         else:
-            export_to_video(frames, args.video_save_dir + "/gen_video_" + str(i).zfill(2) + ".mp4", fps=7)
+            export_to_video(
+                frames, args.video_save_dir + "/gen_video_" + str(i).zfill(2) + ".mp4", fps=args.fps
+            )

     if args.save_frames_as_images:
         for j, frame in enumerate(frames):
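The Wan2.2 branch in `image_to_video_generation.py` builds its `GaudiConfig` from a small kwargs dict: fused Adam and fused clip-norm kernels are always requested, and torch autocast is added only when `--bf16` is passed. A minimal standalone sketch of that selection logic (the helper name is mine, for illustration only):

```python
def build_gaudi_config_kwargs(bf16: bool) -> dict:
    """Mirror the Wan2.2 GaudiConfig setup from image_to_video_generation.py."""
    # Fused optimizer kernels are enabled unconditionally for the Wan2.2 path
    kwargs = {"use_fused_adam": True, "use_fused_clip_norm": True}
    if bf16:
        # --bf16 additionally turns on torch autocast in the Gaudi config
        kwargs["use_torch_autocast"] = True
    return kwargs
```

The resulting dict is what the script unpacks into `GaudiConfig(**gaudi_config_kwargs)`.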

examples/stable-diffusion/requirements.txt

Lines changed: 2 additions & 0 deletions

@@ -2,3 +2,5 @@ opencv-python
 compel
 sentencepiece
 peft == 0.17.0
+ftfy
+

examples/stable-diffusion/text_to_video_generation.py

Lines changed: 32 additions & 2 deletions

@@ -23,7 +23,7 @@
 import torch
 from diffusers.utils.export_utils import export_to_video

-from optimum.habana.diffusers import GaudiCogVideoXPipeline, GaudiTextToVideoSDPipeline
+from optimum.habana.diffusers import GaudiCogVideoXPipeline, GaudiTextToVideoSDPipeline, GaudiWanPipeline
 from optimum.habana.transformers.gaudi_configuration import GaudiConfig
 from optimum.habana.utils import set_seed

@@ -56,7 +56,7 @@ def main():
         "--pipeline_type",
         type=str,
         default="stable_diffusion",
-        help="pipeline type:stable_diffusion or cogvideoX",
+        help="pipeline type:stable_diffusion, cogvideoX or wan",
     )
     # Pipeline arguments
     parser.add_argument(

@@ -192,6 +192,8 @@ def main():
         pipeline: GaudiCogVideoXPipeline = GaudiCogVideoXPipeline.from_pretrained(args.model_name_or_path, **kwargs)
         pipeline.vae.enable_tiling()
         pipeline.vae.enable_slicing()
+    elif args.pipeline_type == "wan":
+        pipeline: GaudiWanPipeline = GaudiWanPipeline.from_pretrained(args.model_name_or_path, **kwargs)
     else:
         logger.error(f"unsupported pipeline type {args.pipeline_type}")
         return None

@@ -239,6 +241,34 @@ def main():
         video_save_dir.mkdir(parents=True, exist_ok=True)
         filename = video_save_dir / "cogvideoX_out.mp4"
         export_to_video(video, str(filename.resolve()), fps=8)
+    elif args.pipeline_type == "wan":
+        set_seed(args.seed)
+        outputs = pipeline(
+            prompt=args.prompts,
+            num_videos_per_prompt=args.num_videos_per_prompt,
+            num_inference_steps=args.num_inference_steps,
+            guidance_scale=args.guidance_scale,
+            negative_prompt=args.negative_prompts,
+            output_type="np" if args.output_type == "mp4" else args.output_type,
+            **kwargs_call,
+        )
+
+        # Save the pipeline in the specified directory if not None
+        if args.pipeline_save_dir is not None:
+            pipeline.save_pretrained(args.pipeline_save_dir)
+
+        # Save videos in the specified directory if not None
+        if args.video_save_dir is not None:
+            if args.output_type == "mp4":
+                video_save_dir = Path(args.video_save_dir)
+                video_save_dir.mkdir(parents=True, exist_ok=True)
+                logger.info(f"Saving videos in {video_save_dir.resolve()}...")
+
+                for i, video in enumerate(outputs.frames):
+                    filename = video_save_dir / f"wan_video_{i + 1}.mp4"
+                    export_to_video(video, str(filename.resolve()), fps=16)
+            else:
+                logger.warning("--output_type should be equal to 'mp4' to save videos in --video_save_dir.")


 if __name__ == "__main__":
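The wan branch in `text_to_video_generation.py` asks the pipeline for `np` frames whenever the user requests an mp4, since `export_to_video` consumes numpy arrays; any other output type passes through unchanged. That mapping, isolated as a sketch (the helper name is mine, not part of the script):

```python
def wan_pipeline_output_type(cli_output_type: str) -> str:
    """mp4 export needs numpy frames from the pipeline; pass other types through."""
    return "np" if cli_output_type == "mp4" else cli_output_type
```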

optimum/habana/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,8 @@
 from .pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_inpaint import GaudiStableDiffusionXLInpaintPipeline
 from .pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import GaudiStableVideoDiffusionPipeline
 from .pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import GaudiTextToVideoSDPipeline
+from .pipelines.wan.pipeline_wan import GaudiWanPipeline
+from .pipelines.wan.pipeline_wan_i2v import GaudiWanImageToVideoPipeline
 from .schedulers import (
     GaudiDDIMScheduler,
     GaudiEulerAncestralDiscreteScheduler,

optimum/habana/diffusers/models/attention_processor.py

Lines changed: 112 additions & 0 deletions

@@ -19,6 +19,7 @@
 import torch
 import torch.nn.functional as F
 from diffusers.models.attention_processor import Attention
+from diffusers.models.transformers.transformer_wan import WanAttention, _get_added_kv_projections, _get_qkv_projections
 from diffusers.utils import deprecate, logging
 from diffusers.utils.import_utils import is_xformers_available
 from torch import nn

@@ -535,4 +536,115 @@ def __call__(
     return hidden_states


+class GaudiWanAttnProcessor:
+    r"""
+    Adapted from: https://github.com/huggingface/diffusers/blob/v0.35.1/src/diffusers/models/transformers/transformer_wan.py#L67
+
+    This class is copied from `WanAttnProcessor` and overrides methods with Gaudi-specific
+    implementations: it adds a `_native_attention` method that uses FusedSDPA on Gaudi, and
+    applies rotary embeddings with `hpex.kernels.apply_rotary_pos_emb`.
+    """
+
+    _attention_backend = None
+
+    def __init__(self, is_training=False):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "WanAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
+            )
+        self.is_training = is_training
+
+    def _native_attention(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: Optional[torch.Tensor] = None,
+        dropout_p: float = 0.0,
+        is_causal: bool = False,
+        scale: Optional[float] = None,
+        enable_gqa: bool = False,
+    ) -> torch.Tensor:
+        # Apply Gaudi fused SDPA
+        from habana_frameworks.torch.hpex.kernels import FusedSDPA
+
+        # Fast FusedSDPA is not supported in training mode
+        fsdpa_mode = "None" if self.is_training else "fast"
+        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
+        out = FusedSDPA.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, fsdpa_mode, None)
+        out = out.permute(0, 2, 1, 3)
+        return out
+
+    def __call__(
+        self,
+        attn: "WanAttention",
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> torch.Tensor:
+        encoder_hidden_states_img = None
+        if attn.add_k_proj is not None:
+            # 512 is the context length of the text encoder, hardcoded for now
+            image_context_length = encoder_hidden_states.shape[1] - 512
+            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
+            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
+
+        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
+
+        query = attn.norm_q(query)
+        key = attn.norm_k(key)
+
+        query = query.unflatten(2, (attn.heads, -1))
+        key = key.unflatten(2, (attn.heads, -1))
+        value = value.unflatten(2, (attn.heads, -1))
+
+        if rotary_emb is not None:
+            """
+            Wan's RoPE is pairwise, like this:
+            def apply_rotary_emb(
+                hidden_states: torch.Tensor,
+                freqs_cos: torch.Tensor,
+                freqs_sin: torch.Tensor,
+            ):
+                x1, x2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
+                cos = freqs_cos[..., 0::2]
+                sin = freqs_sin[..., 1::2]
+                out = torch.empty_like(hidden_states)
+                out[..., 0::2] = x1 * cos - x2 * sin
+                out[..., 1::2] = x1 * sin + x2 * cos
+                return out.type_as(hidden_states)
+            """
+            from habana_frameworks.torch.hpex.kernels import RotaryPosEmbeddingMode, apply_rotary_pos_emb

+            query = apply_rotary_pos_emb(query, *rotary_emb, None, 0, RotaryPosEmbeddingMode.PAIRWISE)
+            key = apply_rotary_pos_emb(key, *rotary_emb, None, 0, RotaryPosEmbeddingMode.PAIRWISE)
+
+        # I2V task
+        hidden_states_img = None
+        if encoder_hidden_states_img is not None:
+            key_img, value_img = _get_added_kv_projections(attn, encoder_hidden_states_img)
+            key_img = attn.norm_added_k(key_img)
+
+            key_img = key_img.unflatten(2, (attn.heads, -1))
+            value_img = value_img.unflatten(2, (attn.heads, -1))
+
+            hidden_states_img = self._native_attention(query, key_img, value_img, None, 0.0, False, None)
+
+            hidden_states_img = hidden_states_img.flatten(2, 3)
+            hidden_states_img = hidden_states_img.type_as(query)
+
+        hidden_states = self._native_attention(query, key, value, attention_mask, 0.0, False, None)
+
+        hidden_states = hidden_states.flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        if hidden_states_img is not None:
+            hidden_states = hidden_states + hidden_states_img
+
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+        return hidden_states
+
+
 AttentionProcessor = Union[AttnProcessor2_0,]
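The docstring quoted in `GaudiWanAttnProcessor` spells out Wan's pairwise RoPE layout, which the Gaudi kernel reproduces via `RotaryPosEmbeddingMode.PAIRWISE`: each consecutive pair of features is rotated in place, with the rotated components written back to interleaved even/odd slots. A dependency-free sketch of that rotation on a flat feature vector (my own reference code, not the HPU kernel):

```python
import math

def pairwise_rope(x: list, thetas: list) -> list:
    """Rotate consecutive pairs (x[2i], x[2i+1]) by angle thetas[i],
    matching the interleaved layout in Wan's apply_rotary_emb."""
    out = [0.0] * len(x)
    for i, theta in enumerate(thetas):
        x1, x2 = x[2 * i], x[2 * i + 1]
        c, s = math.cos(theta), math.sin(theta)
        out[2 * i] = x1 * c - x2 * s      # even slots: x1*cos - x2*sin
        out[2 * i + 1] = x1 * s + x2 * c  # odd slots:  x1*sin + x2*cos
    return out
```

Because each pair undergoes a plane rotation, a zero angle leaves the input unchanged and any angle preserves the pair's norm, which is what makes RoPE safe to apply to queries and keys before attention.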
