[Model] support Ltx2 text-to-video image-to-video #841
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,12 @@ def parse_args() -> argparse.Namespace: | |
| parser.add_argument("--height", type=int, default=720, help="Video height.") | ||
| parser.add_argument("--width", type=int, default=1280, help="Video width.") | ||
| parser.add_argument("--num_frames", type=int, default=81, help="Number of frames (Wan default is 81).") | ||
| parser.add_argument( | ||
| "--frame_rate", | ||
| type=float, | ||
| default=None, | ||
| help="Optional generation frame rate (used by models like LTX2). Defaults to --fps.", | ||
| ) | ||
| parser.add_argument("--num_inference_steps", type=int, default=40, help="Sampling steps.") | ||
| parser.add_argument( | ||
| "--boundary_ratio", | ||
|
|
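A minimal, self-contained sketch of the fallback the new `--frame_rate` flag relies on: when it is not supplied, the script later falls back to `--fps`. This is not the PR's parser; the `--fps` default of 24 below is only for illustration.

```python
# Sketch of the --frame_rate fallback introduced in this hunk. If the flag is
# omitted, the generation frame rate falls back to --fps (mirrors the later
# `frame_rate = args.frame_rate if ... else float(args.fps)` line).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--fps", type=int, default=24, help="Output video FPS (default here is illustrative).")
parser.add_argument(
    "--frame_rate",
    type=float,
    default=None,
    help="Optional generation frame rate (used by models like LTX2). Defaults to --fps.",
)

args = parser.parse_args(["--fps", "24"])  # no --frame_rate supplied
frame_rate = args.frame_rate if args.frame_rate is not None else float(args.fps)
assert frame_rate == 24.0
```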
@@ -41,17 +47,6 @@ def parse_args() -> argparse.Namespace: | |
| parser.add_argument( | ||
| "--flow_shift", type=float, default=5.0, help="Scheduler flow_shift (5.0 for 720p, 12.0 for 480p)." | ||
| ) | ||
| parser.add_argument( | ||
| "--cache_backend", | ||
| type=str, | ||
| default=None, | ||
| choices=["cache_dit"], | ||
| help=( | ||
| "Cache backend to use for acceleration. " | ||
| "Options: 'cache_dit' (DBCache + SCM + TaylorSeer). " | ||
| "Default: None (no cache acceleration)." | ||
| ), | ||
| ) | ||
| parser.add_argument( | ||
| "--enable-cache-dit-summary", | ||
| action="store_true", | ||
|
|
@@ -115,12 +110,31 @@ def parse_args() -> argparse.Namespace: | |
| default=1, | ||
| help="Number of GPUs used for tensor parallelism (TP) inside the DiT.", | ||
| ) | ||
| parser.add_argument( | ||
| "--audio_sample_rate", | ||
| type=int, | ||
| default=24000, | ||
| help="Sample rate for audio output when saved (default: 24000 for LTX2).", | ||
| ) | ||
| parser.add_argument( | ||
| "--cache_backend", | ||
| type=str, | ||
| default=None, | ||
| choices=["cache_dit", "tea_cache"], | ||
| help=( | ||
| "Cache backend to use for acceleration. " | ||
| "Options: 'cache_dit' (DBCache + SCM + TaylorSeer), 'tea_cache' (Timestep Embedding Aware Cache). " | ||
| "Default: None (no cache acceleration)." | ||
| ), | ||
| ) | ||
|
|
||
| return parser.parse_args() | ||
|
|
||
|
|
||
| def main(): | ||
| args = parse_args() | ||
| generator = torch.Generator(device=current_omni_platform.device_type).manual_seed(args.seed) | ||
| frame_rate = args.frame_rate if args.frame_rate is not None else float(args.fps) | ||
|
|
||
| # Wan2.2 cache-dit tuning (from cache-dit examples and cache_alignment). | ||
| cache_config = None | ||
|
|
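A hypothetical sketch of how the parsed `--cache_backend` value might be mapped to a backend-specific config before it is handed to the pipeline. The `cache_dit` keys mirror the tuning dict shown further down in this file; the `tea_cache` branch is a placeholder, since its defaults are not shown in this diff.

```python
# Hypothetical dispatch from the CLI flag to a cache config dict.
# Only the cache_dit keys are taken from this diff; everything else is assumed.
from typing import Optional


def build_cache_config(cache_backend: Optional[str]) -> Optional[dict]:
    if cache_backend is None:
        return None  # no cache acceleration
    if cache_backend == "cache_dit":
        return {
            "scm_steps_mask_policy": None,  # None (disabled), "slow", "medium", "fast", "ultra"
            "scm_steps_policy": "dynamic",  # "dynamic" or "static"
        }
    if cache_backend == "tea_cache":
        return {}  # placeholder: real defaults live in the engine, not shown here
    raise ValueError(f"Unsupported cache backend: {cache_backend}")


print(build_cache_config("cache_dit"))
```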
@@ -140,8 +154,7 @@ def main(): | |
| "scm_steps_mask_policy": None, # SCM mask policy: None (disabled), "slow", "medium", "fast", "ultra" | ||
| "scm_steps_policy": "dynamic", # SCM steps policy: "dynamic" or "static" | ||
| } | ||
| # Configure parallel settings (only SP is supported for Wan) | ||
| # Note: cfg_parallel and tensor_parallel are not implemented for Wan models | ||
| # Configure parallel settings | ||
| parallel_config = DiffusionParallelConfig( | ||
| ulysses_degree=args.ulysses_degree, | ||
| ring_degree=args.ring_degree, | ||
|
|
@@ -160,12 +173,12 @@ def main(): | |
| vae_use_tiling=args.vae_use_tiling, | ||
| boundary_ratio=args.boundary_ratio, | ||
| flow_shift=args.flow_shift, | ||
| cache_backend=args.cache_backend, | ||
| cache_config=cache_config, | ||
| enable_cache_dit_summary=args.enable_cache_dit_summary, | ||
| enable_cpu_offload=args.enable_cpu_offload, | ||
| parallel_config=parallel_config, | ||
| enforce_eager=args.enforce_eager, | ||
| cache_backend=args.cache_backend, | ||
| cache_config=cache_config, | ||
| ) | ||
|
|
||
| if profiler_enabled: | ||
|
|
@@ -179,7 +192,11 @@ def main(): | |
| print(f" Inference steps: {args.num_inference_steps}") | ||
| print(f" Frames: {args.num_frames}") | ||
| print( | ||
| f" Parallel configuration: ulysses_degree={args.ulysses_degree}, ring_degree={args.ring_degree}, cfg_parallel_size={args.cfg_parallel_size}, tensor_parallel_size={args.tensor_parallel_size}" | ||
| " Parallel configuration: " | ||
| f"ulysses_degree={args.ulysses_degree}, " | ||
| f"ring_degree={args.ring_degree}, " | ||
| f"cfg_parallel_size={args.cfg_parallel_size}, " | ||
| f"tensor_parallel_size={args.tensor_parallel_size}" | ||
| ) | ||
| print(f" Video size: {args.width}x{args.height}") | ||
| print(f"{'=' * 60}\n") | ||
|
|
@@ -198,6 +215,7 @@ def main(): | |
| guidance_scale_2=args.guidance_scale_high, | ||
| num_inference_steps=args.num_inference_steps, | ||
| num_frames=args.num_frames, | ||
| frame_rate=frame_rate, | ||
| ), | ||
| ) | ||
| generation_end = time.perf_counter() | ||
|
|
@@ -206,64 +224,184 @@ def main(): | |
| # Print profiling results | ||
| print(f"Total generation time: {generation_time:.4f} seconds ({generation_time * 1000:.2f} ms)") | ||
|
|
||
| # Extract video frames from OmniRequestOutput | ||
| if isinstance(frames, list) and len(frames) > 0: | ||
| first_item = frames[0] | ||
| audio = None | ||
| if isinstance(frames, list): | ||
| frames = frames[0] if frames else None | ||
|
|
||
| # Check if it's an OmniRequestOutput | ||
| if hasattr(first_item, "final_output_type"): | ||
| if first_item.final_output_type != "image": | ||
| raise ValueError( | ||
| f"Unexpected output type '{first_item.final_output_type}', expected 'image' for video generation." | ||
| ) | ||
|
|
||
| # Pipeline mode: extract from nested request_output | ||
| if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output: | ||
| if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0: | ||
| inner_output = first_item.request_output[0] | ||
| if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"): | ||
| frames = inner_output.images[0] if inner_output.images else None | ||
| if frames is None: | ||
| raise ValueError("No video frames found in output.") | ||
| # Diffusion mode: use direct images field | ||
| elif hasattr(first_item, "images") and first_item.images: | ||
| frames = first_item.images | ||
| if isinstance(frames, OmniRequestOutput): | ||
| if frames.final_output_type != "image": | ||
| raise ValueError( | ||
| f"Unexpected output type '{frames.final_output_type}', expected 'image' for video generation." | ||
| ) | ||
| if frames.multimodal_output and "audio" in frames.multimodal_output: | ||
| audio = frames.multimodal_output["audio"] | ||
| if frames.is_pipeline_output and frames.request_output is not None: | ||
| inner_output = frames.request_output | ||
| if isinstance(inner_output, list): | ||
| inner_output = inner_output[0] if inner_output else None | ||
| if isinstance(inner_output, OmniRequestOutput): | ||
| if inner_output.multimodal_output and "audio" in inner_output.multimodal_output: | ||
| audio = inner_output.multimodal_output["audio"] | ||
| frames = inner_output | ||
| if isinstance(frames, OmniRequestOutput): | ||
| if frames.images: | ||
| if len(frames.images) == 1 and isinstance(frames.images[0], tuple) and len(frames.images[0]) == 2: | ||
| frames, audio = frames.images[0] | ||
| elif len(frames.images) == 1 and isinstance(frames.images[0], dict): | ||
| audio = frames.images[0].get("audio") | ||
| frames = frames.images[0].get("frames") or frames.images[0].get("video") | ||
| else: | ||
| frames = frames.images | ||
| else: | ||
| raise ValueError("No video frames found in OmniRequestOutput.") | ||
|
|
||
| if isinstance(frames, list) and frames: | ||
| first_item = frames[0] | ||
| if isinstance(first_item, tuple) and len(first_item) == 2: | ||
| frames, audio = first_item | ||
| elif isinstance(first_item, dict): | ||
| audio = first_item.get("audio") | ||
| frames = first_item.get("frames") or first_item.get("video") | ||
| elif isinstance(first_item, list): | ||
| frames = first_item | ||
|
|
||
| if isinstance(frames, tuple) and len(frames) == 2: | ||
| frames, audio = frames | ||
| elif isinstance(frames, dict): | ||
| audio = frames.get("audio") | ||
| frames = frames.get("frames") or frames.get("video") | ||
|
|
||
| if frames is None: | ||
| raise ValueError("No video frames found in output.") | ||
|
Comment on lines +227 to +275
|
||
|
|
||
| output_path = Path(args.output) | ||
| output_path.parent.mkdir(parents=True, exist_ok=True) | ||
| try: | ||
| from diffusers.utils import export_to_video | ||
| except ImportError: | ||
| raise ImportError("diffusers is required for export_to_video.") | ||
|
|
||
| # frames may be np.ndarray (preferred) or torch.Tensor | ||
| def _normalize_frame(frame): | ||
| if isinstance(frame, torch.Tensor): | ||
| frame_tensor = frame.detach().cpu() | ||
| if frame_tensor.dim() == 4 and frame_tensor.shape[0] == 1: | ||
| frame_tensor = frame_tensor[0] | ||
| if frame_tensor.dim() == 3 and frame_tensor.shape[0] in (3, 4): | ||
| frame_tensor = frame_tensor.permute(1, 2, 0) | ||
| if frame_tensor.is_floating_point(): | ||
| frame_tensor = frame_tensor.clamp(-1, 1) * 0.5 + 0.5 | ||
| return frame_tensor.float().numpy() | ||
| if isinstance(frame, np.ndarray): | ||
| frame_array = frame | ||
| if frame_array.ndim == 4 and frame_array.shape[0] == 1: | ||
| frame_array = frame_array[0] | ||
| if np.issubdtype(frame_array.dtype, np.integer): | ||
| frame_array = frame_array.astype(np.float32) / 255.0 | ||
| return frame_array | ||
| try: | ||
| from PIL import Image | ||
| except ImportError: | ||
| Image = None | ||
| if Image is not None and isinstance(frame, Image.Image): | ||
| return np.asarray(frame).astype(np.float32) / 255.0 | ||
| return frame | ||
|
|
||
| def _ensure_frame_list(video_array): | ||
| if isinstance(video_array, list): | ||
| if len(video_array) == 0: | ||
| return video_array | ||
| first_item = video_array[0] | ||
| if isinstance(first_item, np.ndarray): | ||
| if first_item.ndim == 5: | ||
| return list(first_item[0]) | ||
| if first_item.ndim == 4: | ||
|                     return list(first_item) | ||
| if first_item.ndim == 3: | ||
| return video_array | ||
| return video_array | ||
| if isinstance(video_array, np.ndarray): | ||
| if video_array.ndim == 5: | ||
| return list(video_array[0]) | ||
| if video_array.ndim == 4: | ||
| return list(video_array) | ||
| if video_array.ndim == 3: | ||
| return [video_array] | ||
| return video_array | ||
|
|
||
| # frames may be np.ndarray, torch.Tensor, or list of tensors/arrays/images | ||
| # export_to_video expects a list of frames with values in [0, 1] | ||
| if isinstance(frames, torch.Tensor): | ||
| video_tensor = frames.detach().cpu() | ||
| if video_tensor.dim() == 5: | ||
| # [B, C, F, H, W] or [B, F, H, W, C] | ||
| if video_tensor.shape[1] in (3, 4): | ||
| video_tensor = video_tensor[0].permute(1, 2, 3, 0) | ||
| else: | ||
| video_tensor = video_tensor[0] | ||
| elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): | ||
| video_tensor = video_tensor.permute(1, 2, 3, 0) | ||
| # If float, assume [-1,1] and normalize to [0,1] | ||
| if video_tensor.is_floating_point(): | ||
| video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 | ||
| video_array = video_tensor.float().numpy() | ||
| else: | ||
| elif isinstance(frames, np.ndarray): | ||
| video_array = frames | ||
| if hasattr(video_array, "shape") and video_array.ndim == 5: | ||
| if video_array.ndim == 5: | ||
| video_array = video_array[0] | ||
| if np.issubdtype(video_array.dtype, np.integer): | ||
| video_array = video_array.astype(np.float32) / 255.0 | ||
| elif isinstance(frames, list): | ||
| if len(frames) == 0: | ||
| raise ValueError("No video frames found in output.") | ||
| video_array = [_normalize_frame(frame) for frame in frames] | ||
| else: | ||
| video_array = frames | ||
|
|
||
| video_array = _ensure_frame_list(video_array) | ||
|
|
||
| use_ltx2_export = False | ||
| if args.model and "ltx" in str(args.model).lower(): | ||
| use_ltx2_export = True | ||
| if audio is not None: | ||
| use_ltx2_export = True | ||
|
|
||
| # Convert 4D array (frames, H, W, C) to list of frames for export_to_video | ||
| if isinstance(video_array, np.ndarray) and video_array.ndim == 4: | ||
| video_array = list(video_array) | ||
| if use_ltx2_export: | ||
| try: | ||
| from diffusers.pipelines.ltx2.export_utils import encode_video | ||
| except ImportError: | ||
| raise ImportError("diffusers is required for LTX2 encode_video.") | ||
|
|
||
| export_to_video(video_array, str(output_path), fps=args.fps) | ||
| if isinstance(video_array, list): | ||
| frames_np = np.stack(video_array, axis=0) | ||
| elif isinstance(video_array, np.ndarray): | ||
| frames_np = video_array | ||
| else: | ||
| frames_np = np.asarray(video_array) | ||
|
|
||
| frames_u8 = (frames_np * 255).round().clip(0, 255).astype("uint8") | ||
| video_tensor = torch.from_numpy(frames_u8) | ||
|
|
||
| audio_out = None | ||
| if audio is not None: | ||
| if isinstance(audio, list): | ||
| audio = audio[0] if audio else None | ||
| if isinstance(audio, np.ndarray): | ||
| audio = torch.from_numpy(audio) | ||
| if isinstance(audio, torch.Tensor): | ||
| audio_out = audio | ||
| if audio_out.dim() > 1: | ||
| audio_out = audio_out[0] | ||
| audio_out = audio_out.float().cpu() | ||
|
|
||
| encode_video( | ||
| video_tensor, | ||
| fps=args.fps, | ||
| audio=audio_out, | ||
| audio_sample_rate=args.audio_sample_rate if audio_out is not None else None, | ||
| output_path=str(output_path), | ||
| ) | ||
| else: | ||
| export_to_video(video_array, str(output_path), fps=args.fps) | ||
| print(f"Saved generated video to {output_path}") | ||
|
|
||
| if profiler_enabled: | ||
|
|
||
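For reference, a standalone sketch of the shape/range contract the new `_normalize_frame` helper enforces before `export_to_video`: a CHW float tensor assumed to be in [-1, 1] becomes an HWC float32 array in [0, 1]. This is not the PR's code, just the same logic restated in isolation.

```python
# Standalone restatement of the frame normalization performed in the example:
# drop a leading batch dim, move channels last, and map [-1, 1] floats to [0, 1],
# which is the format diffusers' export_to_video expects.
import numpy as np
import torch


def normalize_frame(frame: torch.Tensor) -> np.ndarray:
    frame = frame.detach().cpu()
    if frame.dim() == 4 and frame.shape[0] == 1:  # [1, C, H, W] -> [C, H, W]
        frame = frame[0]
    if frame.dim() == 3 and frame.shape[0] in (3, 4):  # CHW -> HWC
        frame = frame.permute(1, 2, 0)
    if frame.is_floating_point():  # assume [-1, 1], rescale to [0, 1]
        frame = frame.clamp(-1, 1) * 0.5 + 0.5
    return frame.float().numpy()


frame = torch.rand(3, 720, 1280) * 2 - 1  # fake CHW frame in [-1, 1]
out = normalize_frame(frame)
assert out.shape == (720, 1280, 3) and out.min() >= 0.0 and out.max() <= 1.0
```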
How about the other inference examples?