
[None][refactor] parallel vae refactor #12123

Open

NVShreyas wants to merge 1 commit into NVIDIA:main from
NVShreyas:user/shreyasm/parallel-vae-refactor

Conversation

NVShreyas (Collaborator) commented Mar 11, 2026

Summary by CodeRabbit

Release Notes

  • Refactor
    • Restructured VAE parallelization architecture with updated interfaces and factory-based initialization approach.
    • Updated VAE adapter naming conventions and revised configuration patterns for parallelization.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
@NVShreyas NVShreyas requested a review from a team as a code owner March 11, 2026 20:36
NVShreyas (Collaborator, Author) commented:

/bot run --disable-fail-fast

coderabbitai bot (Contributor) commented Mar 11, 2026

📝 Walkthrough

This pull request refactors the VAE parallelization system by replacing the adapter-based architecture with a factory-based approach. It introduces new abstractions (ParallelVAEBase, SplitSpec, ParallelVAEFactory) and renames WanParallelVAEAdapter to ParallelVAE_Wan. The pipeline setup is updated to use the factory pattern for VAE wrapping, and the legacy setup_parallel_vae function is removed.
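The factory-based dispatch described in this walkthrough can be sketched as follows. This is a hand-written illustration, not the PR's code: the registry contents, the `_registry` name, and the `importlib`-based lazy loading are all assumptions about how `ParallelVAEFactory.from_vae` might resolve a wrapper class for a given VAE.

```python
# Illustrative sketch of a lazy-loading factory for parallel VAE wrappers.
# Only ParallelVAEFactory.from_vae and ParallelVAE_Wan appear in the PR
# summary; everything else here is a hypothetical stand-in.
import importlib


class ParallelVAEFactory:
    # Hypothetical registry: VAE class name -> "module:WrapperClass", imported
    # lazily so unsupported model families cost nothing at import time.
    _registry = {
        "AutoencoderKLWan": "wan.parallel_vae:ParallelVAE_Wan",
    }

    @classmethod
    def from_vae(cls, vae):
        key = type(vae).__name__
        if key not in cls._registry:
            raise ValueError(f"No parallel VAE wrapper registered for {key}")
        module_path, _, class_name = cls._registry[key].partition(":")
        wrapper_cls = getattr(importlib.import_module(module_path), class_name)
        return wrapper_cls(vae)
```

The `ValueError` on an unregistered VAE type matches the error path the review comment below discusses; callers that create process groups before calling the factory need to clean up when this is raised.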

Changes

Cohort / File(s) Summary
VAE Parallel Interface & Base
tensorrt_llm/_torch/visual_gen/modules/vae/parallel_vae_interface.py
Introduces SplitSpec dataclass, ParallelVAEBase nn.Module with encode/decode delegation and hooks for subclasses, and ParallelVAEFactory for lazy-loaded VAE wrapper instantiation. Replaces old BaseParallelVAEAdapter with new architecture supporting process group management and tensor split/gather helpers.
VAE Module Exports
tensorrt_llm/_torch/visual_gen/modules/vae/__init__.py
Updates public API to export ParallelVAEBase, ParallelVAEFactory, and SplitSpec instead of BaseParallelVAEAdapter.
WAN Model Parallel VAE
tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py, tensorrt_llm/_torch/visual_gen/models/wan/__init__.py
Renames WanParallelVAEAdapter to ParallelVAE_Wan and refactors it to inherit from ParallelVAEBase. Introduces a make_spec() staticmethod, replaces chunk_dims/adj_groups with spec/_adj_groups, adds _encode_impl/_decode_impl, and updates the module replacement logic to use the new spec structure.
WAN Pipeline Updates
tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py, tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
Removes vae_adapter_class property and associated imports for WanParallelVAEAdapter.
Parallelism & Pipeline Setup
tensorrt_llm/_torch/visual_gen/parallelism.py, tensorrt_llm/_torch/visual_gen/pipeline.py
Removes setup_parallel_vae function; updates BasePipeline.setup_parallel_vae to use ParallelVAEFactory.from_vae() for VAE wrapping. Removes vae_adapter_class property and adapter-based initialization logic; updates logging and error messages to reflect factory-based approach.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~65 minutes

🚥 Pre-merge checks | ❌ 3

❌ Failed checks (2 warnings, 1 inconclusive)

  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 50.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Description check — ⚠️ Warning. The PR description is entirely templated, with the required Description and Test Coverage sections left empty. Resolution: complete the Description section explaining the refactor's rationale, scope, and API changes, and detail Test Coverage with specific tests validating the parallel VAE factory pattern and module parallelization hooks.
  • Title check — ❓ Inconclusive. The title is related to the changeset but does not convey the main architectural change. Resolution: enhance the title, e.g. "Refactor parallel VAE to use factory pattern and base interface" or "Replace adapter-based parallel VAE with factory-based architecture".

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Comment @coderabbitai help to get the list of available commands and usage tips.

coderabbitai bot left a review comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
tensorrt_llm/_torch/visual_gen/modules/vae/parallel_vae_interface.py (1)

127-145: Consider documenting process group lifecycle.

_build_adj_groups creates pairwise process groups that persist for the lifetime of the process. This is standard for distributed setups, but a brief docstring note about cleanup responsibility (if any) would be helpful for maintainers.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/visual_gen/modules/vae/parallel_vae_interface.py` around
lines 127 - 145, Add a short note to the _build_adj_groups docstring documenting
that the created ProcessGroup objects (returned in adj_groups) persist for the
process lifetime and that callers are responsible for any explicit cleanup (if
needed) or should rely on the process teardown to release them; reference the
_build_adj_groups static method and the returned adj_groups/ProcessGroup objects
so maintainers know where lifecycle and cleanup responsibility applies.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/visual_gen/pipeline.py`:
- Around line 204-219: The new process group created by dist.new_group is leaked
when ParallelVAEFactory.from_vae raises ValueError; ensure the group is
destroyed on failure by calling dist.destroy_process_group(pg) (or the
appropriate destroy function) inside the except ValueError block before
returning, referencing the created pg and the call to
ParallelVAEFactory.from_vae (and self.vae) so the cleanup runs only on the
failure path.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 64e3f104-2807-4a54-9bf2-4065f4f8946d

📥 Commits

Reviewing files that changed from the base of the PR and between ea4d4d1 and 6b673d0.

📒 Files selected for processing (8)
  • tensorrt_llm/_torch/visual_gen/models/wan/__init__.py
  • tensorrt_llm/_torch/visual_gen/models/wan/parallel_vae.py
  • tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
  • tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
  • tensorrt_llm/_torch/visual_gen/modules/vae/__init__.py
  • tensorrt_llm/_torch/visual_gen/modules/vae/parallel_vae_interface.py
  • tensorrt_llm/_torch/visual_gen/parallelism.py
  • tensorrt_llm/_torch/visual_gen/pipeline.py

tensorrt-cicd (Collaborator) commented:

PR_Github #38631 [ run ] triggered by Bot. Commit: 6b673d0 Link to invocation

chang-l (Collaborator) left a review comment


Thanks for the effort, @NVShreyas. Just a heads-up: I am not sure whether this refactor/design could help enable parallel VAE for LTX-2 in the future.
https://github.com/Lightricks/LTX-2/blob/ae855f8538843825f9015a419cf4ba5edaf5eec2/packages/ltx-core/src/ltx_core/model/video_vae/video_vae.py#L802

NVShreyas (Collaborator, Author) commented Mar 11, 2026

@chang-l I see LTX-2 uses tiling. Supporting that would need a parallel VAE based on tile distribution, which would require more changes, but the halo exchange and this refactor should still hold.
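The halo-exchange idea referred to above can be illustrated with a minimal split: each rank's chunk is extended by `overlap` elements across interior boundaries, so operations near chunk edges still see the neighbor's data. Pure-Python lists stand in for tensors here; this is a conceptual sketch, not code from the PR.

```python
# Conceptual halo split: partition a sequence into num_chunks pieces,
# extending each piece by `overlap` elements at interior boundaries.
def halo_split(seq, num_chunks, overlap):
    n = len(seq)
    base = n // num_chunks
    chunks = []
    for r in range(num_chunks):
        start = max(0, r * base - overlap)
        end = n if r == num_chunks - 1 else min(n, (r + 1) * base + overlap)
        chunks.append(seq[start:end])
    return chunks
```

With `overlap=0` this degenerates to a plain even split; tile distribution for LTX-2 would generalize the same idea to 2D/3D tiles.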
