
[TRTLLM-11288][feat] Configurable warmup shapes for VisualGen#12107

Open
luyiyun1021 wants to merge 2 commits into NVIDIA:main from luyiyun1021:fix/trtllm-11288-visualgen-warmup-config

Conversation


@luyiyun1021 luyiyun1021 commented Mar 11, 2026

@coderabbitai summary

Description

Make VisualGen warmup shapes configurable via a new WarmupConfig sub-config, replacing hardcoded warmup shapes in model pipeline subclasses.

Problem: Warmup shapes (resolution + frame count) were hardcoded in each pipeline class (e.g., [(480, 832, 33), (480, 832, 81), (720, 1280, 81)] for Wan). Users could not configure which shapes to pre-compile at startup, so requesting an un-warmed shape triggered a torch.compile recompilation with seconds of delay.

Solution:

  • New WarmupConfig sub-config with resolutions and num_frames fields, combined via Cartesian product at warmup time
  • BasePipeline.resolve_warmup_plan() to resolve user config vs model defaults (user > model-default > empty)
  • BasePipeline.validate_shape() with model constraint declarations (resolution_multiple_of, is_valid_frame_count)
  • Request-level shape validation in DiffusionExecutor.process_request() — fail-fast for invalid shapes, warning for un-warmed shapes
  • Replace silent shape correction in I2V forward() with explicit validate_shape() error

User-facing YAML config:

warmup:
  resolutions:
    - [480, 832]
    - [720, 1280]
  num_frames: [33, 81]
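For reference, a minimal sketch of how this YAML maps onto the new sub-config. The real WarmupConfig subclasses the repo's pydantic StrictBaseModel; the dataclass below is a stdlib stand-in, with field names taken from the PR description:

```python
from dataclasses import dataclass, field

# Stand-in for the new WarmupConfig sub-config; the real class subclasses
# the repo's pydantic StrictBaseModel. Field names match the PR description.
@dataclass
class WarmupConfig:
    resolutions: list = field(default_factory=list)  # list of (height, width) pairs
    num_frames: list = field(default_factory=list)   # list of frame counts

# The YAML above, once parsed, populates the sub-config like this:
parsed = {"resolutions": [[480, 832], [720, 1280]], "num_frames": [33, 81]}
cfg = WarmupConfig(
    resolutions=[tuple(r) for r in parsed["resolutions"]],
    num_frames=list(parsed["num_frames"]),
)
```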

Design decisions:

  • WarmupConfig as sub-config: follows the same pattern as the six existing sub-configs (torch_compile, cuda_graph, etc.)
  • Cartesian product: resolutions × num_frames — simpler to configure than explicit (h, w, f) triples
  • OOM fail-fast: VisualGen warmup shapes are declared by the user or the model; an OOM during warmup means the GPU cannot serve this config
  • warmup_steps not exposed: stays as model-internal heuristic (default_warmup_steps)
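The Cartesian-product and priority-chain decisions above can be sketched as follows. This is a standalone approximation; BasePipeline.resolve_warmup_plan in pipeline.py is the authoritative version, and its exact signature may differ:

```python
from itertools import product

def resolve_warmup_plan(user_resolutions, user_num_frames,
                        default_resolutions, default_num_frames):
    # Priority chain: user config > model default > empty.
    resolutions = user_resolutions or default_resolutions
    num_frames = user_num_frames or default_num_frames
    if not resolutions or not num_frames:
        return []  # nothing configured anywhere: skip warmup
    # Cartesian product: every resolution paired with every frame count.
    return [(h, w, f) for (h, w), f in product(resolutions, num_frames)]

# User config takes precedence over the model default:
plan = resolve_warmup_plan([(480, 832)], [33, 81], [(720, 1280)], [81])
```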

Backward compatibility:

  • Default behavior unchanged when warmup is not configured
  • I2V forward() now raises ValueError on invalid shapes instead of silently rounding (intentional)

Test Coverage

  • tests/unittest/_torch/visual_gen/test_warmup.py (28 new unit tests, CPU only):
    • TestWarmupConfig (6): config construction, defaults, empty lists, strict validation
    • TestVisualGenArgsWarmup (4): dict/YAML construction, pickle serialization
    • TestResolveWarmupPlan (10): Cartesian product, priority chain, empty lists, invalid values, Flux constraints
    • TestValidateShape (7): resolution/frame validation, edge cases, Flux single-frame constraint
    • TestWarmupExecution (3): _warmed_up_shapes tracking, empty shapes skip
    • TestRequestValidation (3): warmed/un-warmed/invalid shape behavior
  • tests/unittest/_torch/visual_gen/test_visual_gen_args.py — 1 new test: WarmupConfig strict validation

Existing CI tests are not affected: all use skip_warmup=True and do not reference removed APIs.

PR Checklist

  • PR description clearly explains what and why.

  • PR Follows TRT-LLM CODING GUIDELINES.

  • Test cases are provided for new code paths.

  • Any new dependencies have been scanned for license and vulnerabilities.

  • CODEOWNERS updated if ownership changes.

  • Documentation updated as needed.

  • Update tava architecture diagram if significant design change.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Add WarmupConfig sub-config to allow users to specify which
(height, width, num_frames) shapes to warmup at startup via YAML.
Warmup shapes are the Cartesian product of resolutions x num_frames.

- New WarmupConfig with resolutions and num_frames fields
- BasePipeline.resolve_warmup_plan() with user > model-default priority
- BasePipeline.validate_shape() with model constraint declarations
- Migrate all 4 pipelines (Wan T2V/I2V, Flux, Flux2) to new interface
- Request-level validation in executor (fail-fast + warmup warning)
- Replace silent shape correction in I2V forward() with explicit error
- 28 new unit tests in test_warmup.py

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
@luyiyun1021 luyiyun1021 requested a review from a team as a code owner March 11, 2026 08:45

coderabbitai bot commented Mar 11, 2026

📝 Walkthrough

This PR refactors the warmup configuration system across visual generation models from a hard-coded approach to a modular, configurable design. It introduces a new WarmupConfig model, replaces common_warmup_shapes with configurable methods (default_warmup_resolutions, default_warmup_num_frames, is_valid_frame_count), updates _run_warmup signatures to accept explicit parameters, and adds shape validation and warmup plan resolution logic.

Changes

Cohort / File(s) / Summary

  • Configuration & Model Setup (tensorrt_llm/_torch/visual_gen/config.py):
    Introduces the WarmupConfig model and adds warmup fields to VisualGenArgs and DiffusionModelConfig. Updates from_pretrained to read and propagate the warmup configuration during model loading.
  • Base Pipeline & Integration (tensorrt_llm/_torch/visual_gen/pipeline.py):
    Replaces common_warmup_shapes with new methods: default_warmup_resolutions, default_warmup_num_frames, default_warmup_steps, resolution_multiple_of, is_valid_frame_count, validate_shape, and resolve_warmup_plan. Introduces _warmed_up_shapes tracking and updates the warmup execution flow.
  • Executor Integration (tensorrt_llm/_torch/visual_gen/executor.py):
    Updates process_request to call validate_shape. Changes the warmup-check condition from common_warmup_shapes to _warmed_up_shapes and updates logging messages.
  • Flux Pipeline Implementations (tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py, tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py):
    Replaces the common_warmup_shapes property with default_warmup_resolutions, default_warmup_num_frames, and is_valid_frame_count methods. Updates the _run_warmup signature from accepting warmup_steps to accepting explicit height, width, num_frames, and steps parameters.
  • Wan Pipeline Implementations (tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py, tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py):
    Refactors warmup configuration with new methods: default_warmup_resolutions, default_warmup_num_frames, default_warmup_steps, resolution_multiple_of, and is_valid_frame_count. Updates _run_warmup to accept explicit parameters. Integrates validate_shape into the forward pass (pipeline_wan_i2v.py).
  • Test Coverage (tests/unittest/_torch/visual_gen/test_visual_gen_args.py, tests/unittest/_torch/visual_gen/test_warmup.py):
    Adds a WarmupConfig validation test to the existing test file. Introduces a comprehensive new test module covering WarmupConfig integration, warmup plan resolution, shape validation, warmup execution, and request validation scenarios.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 36.00%, which is below the required threshold of 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: The title clearly and specifically summarizes the main change: making VisualGen warmup shapes configurable via a new feature.
  • Description check ✅ Passed: The PR description comprehensively explains the problem, solution, design decisions, backward compatibility, and test coverage with specific examples and file changes.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/visual_gen/executor.py`:
- Around line 174-187: Move the call to validate_shape into the existing request
error-handling path so shape validation exceptions are caught and converted into
a DiffusionResponse error instead of bubbling out of process_request;
specifically, in process_request, call self.pipeline.validate_shape(req.height,
req.width, req.num_frames) inside the try block (the same block that constructs
and returns DiffusionResponse on error) rather than before it, ensuring any
raised errors are handled by the existing except/error response logic and
preserving the warmed-up-shapes warning logic around the now-validated request.

In `@tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py`:
- Around line 101-110: The pipeline currently only enforces frame-count via
is_valid_frame_count but must also enforce the FLUX.1 resolution multiple:
update validation (e.g., add or extend a method used by validate_shape) to check
that requested height and width are multiples of (self.vae_scale_factor * 2) and
fail validation otherwise; also update default_warmup_resolutions if needed to
remain valid under that rule. Reference the existing symbols
default_warmup_resolutions, default_warmup_num_frames, is_valid_frame_count,
validate_shape, and self.vae_scale_factor when implementing the check so
non-multiple resolutions are rejected early instead of being silently quantized.

In `@tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py`:
- Around line 169-178: The pipeline currently only declares frame-count
constraints (default_warmup_num_frames / is_valid_frame_count) but not spatial
constraints, so validate_shape() allows inputs that will be floored by
_prepare_latents()/ _prepare_latent_ids; add a spatial-validation hook on the
class (e.g., implement is_valid_spatial_resolution(self, height, width) or
override validate_shape to check spatial constraints) that returns True only for
resolutions that align with the packed latent grid (match entries in
default_warmup_resolutions or are exact multiples of the latent packing stride
used by _prepare_latents/_prepare_latent_ids), so requests like 1025x1024 are
rejected rather than silently floored. Ensure the new method uses the same
latent packing stride/scale as _prepare_latents/_prepare_latent_ids and
integrate it with existing validate_shape() logic.

In `@tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py`:
- Line 456: The code currently calls self.validate_shape(height, width,
num_frames) but does not reject requests where last_image is True and num_frames
< 2; add an explicit guard before preparing latents so that if self.last_image
(or the caller-provided last_image flag) is set and num_frames < 2 you raise a
ValueError with a clear message; adjust the logic either in validate_shape or
immediately around the call site that leads into _prepare_latents() so that
_prepare_latents() never receives num_frames < 2 (it allocates num_frames - 2
intermediate frames).

In `@tensorrt_llm/_torch/visual_gen/pipeline.py`:
- Around line 397-402: The unconditional call to torch.cuda.synchronize() after
running _run_warmup causes errors on CPU-only builds; guard the call with a CUDA
availability check (e.g., torch.cuda.is_available() or torch.cuda.device_count()
> 0) so synchronization only runs when CUDA is present, leaving the warmup loop
and setting of _warmed_up_shapes unchanged.
- Around line 136-146: In validate_shape, add an explicit fast-fail that rejects
non-positive inputs before the model-specific checks: verify height > 0, width >
0, and num_frames > 0 and raise a ValueError (including the class name via
self.__class__.__name__) if any are non-positive; keep this check above the
resolution_multiple_of logic and the is_valid_frame_count call so negative
dimensions or frame counts cannot pass model-specific predicates.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: acd47242-f370-49c3-ab58-92769cdde86c

📥 Commits

Reviewing files that changed from the base of the PR and between 298b6c8 and e6c9546.

📒 Files selected for processing (9)
  • tensorrt_llm/_torch/visual_gen/config.py
  • tensorrt_llm/_torch/visual_gen/executor.py
  • tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
  • tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
  • tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
  • tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
  • tensorrt_llm/_torch/visual_gen/pipeline.py
  • tests/unittest/_torch/visual_gen/test_visual_gen_args.py
  • tests/unittest/_torch/visual_gen/test_warmup.py

Comment on lines 174 to 187 in tensorrt_llm/_torch/visual_gen/executor.py:

     def process_request(self, req: DiffusionRequest):
         """Process a single request."""
+        self.pipeline.validate_shape(req.height, req.width, req.num_frames)
+
         if (
-            self.pipeline.common_warmup_shapes
-            and (req.height, req.width, req.num_frames) not in self.pipeline.common_warmup_shapes
+            self.pipeline._warmed_up_shapes
+            and (req.height, req.width, req.num_frames) not in self.pipeline._warmed_up_shapes
         ):
             logger.warning(
-                f"Requested shape (height={req.height}, width={req.width}, num_frames={req.num_frames}) "
-                f"was not warmed up. First request with this shape will be slower due to "
-                "torch.compile recompilation or CUDA graph capture."
-                f"Warmed-up shapes: {self.pipeline.common_warmup_shapes}"
+                f"Requested shape ({req.height}x{req.width}, {req.num_frames} frames) "
+                f"was not warmed up. First request will be slower due to "
+                f"torch.compile recompilation. "
+                f"Warmed-up shapes: {self.pipeline._warmed_up_shapes}"
             )

⚠️ Potential issue | 🔴 Critical

Keep the new shape validation inside the existing error path.

validate_shape() now runs before the try, so a bad request raises out of process_request(), skips DiffusionResponse(error_msg=...), and can take down the worker instead of failing just that request.

Proposed fix
     def process_request(self, req: DiffusionRequest):
         """Process a single request."""
-        self.pipeline.validate_shape(req.height, req.width, req.num_frames)
-
-        if (
-            self.pipeline._warmed_up_shapes
-            and (req.height, req.width, req.num_frames) not in self.pipeline._warmed_up_shapes
-        ):
-            logger.warning(
-                f"Requested shape ({req.height}x{req.width}, {req.num_frames} frames) "
-                f"was not warmed up. First request will be slower due to "
-                f"torch.compile recompilation. "
-                f"Warmed-up shapes: {self.pipeline._warmed_up_shapes}"
-            )
         try:
+            self.pipeline.validate_shape(req.height, req.width, req.num_frames)
+
+            if (
+                self.pipeline._warmed_up_shapes
+                and (req.height, req.width, req.num_frames) not in self.pipeline._warmed_up_shapes
+            ):
+                logger.warning(
+                    f"Requested shape ({req.height}x{req.width}, {req.num_frames} frames) "
+                    f"was not warmed up. First request will be slower due to "
+                    f"torch.compile recompilation. "
+                    f"Warmed-up shapes: {self.pipeline._warmed_up_shapes}"
+                )
+
             output = self.pipeline.infer(req)
             if self.rank == 0:
                 self.response_queue.put(DiffusionResponse(request_id=req.request_id, output=output))
         except Exception as e:

Comment on lines 101 to 110 in tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py:

     @property
-    def common_warmup_shapes(self) -> list:
-        """Return list of common warmup shapes (height, width, num_frames)."""
-        return [(1024, 1024, 1)]
+    def default_warmup_resolutions(self):
+        return [(1024, 1024)]
+
+    @property
+    def default_warmup_num_frames(self):
+        return [1]
+
+    def is_valid_frame_count(self, num_frames):
+        return num_frames == 1

⚠️ Potential issue | 🟠 Major

Add the FLUX.1 resolution multiple override.

This pipeline now participates in validate_shape(), but only the frame-count constraint is declared. The latent pack/unpack path still quantizes height and width to self.vae_scale_factor * 2, so non-multiple resolutions can pass validation and produce a smaller image than requested instead of failing fast.

Proposed fix
     @property
     def default_warmup_num_frames(self):
         return [1]

+    @property
+    def resolution_multiple_of(self):
+        multiple = self.vae_scale_factor * 2
+        return (multiple, multiple)
+
     def is_valid_frame_count(self, num_frames):
         return num_frames == 1

Comment on lines 169 to 178 in tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py:

     @property
-    def common_warmup_shapes(self) -> list:
-        """Return list of common warmup shapes (height, width, num_frames)."""
-        return [(1024, 1024, 1)]
+    def default_warmup_resolutions(self):
+        return [(1024, 1024)]
+
+    @property
+    def default_warmup_num_frames(self):
+        return [1]
+
+    def is_valid_frame_count(self, num_frames):
+        return num_frames == 1

⚠️ Potential issue | 🟠 Major

Declare FLUX.2's spatial constraint in the new validation hooks.

validate_shape() now relies on subclass-declared constraints, but this pipeline only overrides frame count. _prepare_latents() / _prepare_latent_ids() still floor height and width to the packed latent grid, so a request like 1025x1024 will pass validation and effectively run at 1024x1024 instead of failing fast.

Proposed fix
     @property
     def default_warmup_num_frames(self):
         return [1]

+    @property
+    def resolution_multiple_of(self):
+        multiple = self.vae_scale_factor * 2
+        return (multiple, multiple)
+
     def is_valid_frame_count(self, num_frames):
         return num_frames == 1

Line 456 in tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py:

             f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
         )
         height, width = calc_height, calc_width
+        self.validate_shape(height, width, num_frames)

⚠️ Potential issue | 🟠 Major

Reject last_image requests with fewer than 2 frames.

This only validates the lattice constraints. When last_image is set, _prepare_latents() later allocates num_frames - 2 intermediate frames, so num_frames == 1 still gets through here and then fails with a negative-dimension error instead of the intended fail-fast ValueError.

💡 Proposed fix
         self.validate_shape(height, width, num_frames)
+        if last_image is not None and num_frames < 2:
+            raise ValueError(
+                "Wan I2V interpolation requires num_frames >= 2 when last_image is provided."
+            )

Comment on lines +136 to +146
def validate_shape(self, height: int, width: int, num_frames: int) -> None:
"""Validate shape against model constraints. Raises ValueError."""
h_mul, w_mul = self.resolution_multiple_of
if h_mul > 1 or w_mul > 1:
if height % h_mul != 0 or width % w_mul != 0:
raise ValueError(
f"Resolution ({height}x{width}) must be multiples of "
f"({h_mul}x{w_mul}) for {self.__class__.__name__}."
)
if not self.is_valid_frame_count(num_frames):
raise ValueError(f"Invalid num_frames={num_frames} for {self.__class__.__name__}.")
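To illustrate the constraint logic, here is a standalone approximation with Wan-like values plugged in. The multiples of 16 and the 4k+1 frame rule are illustrative assumptions; the real constraints are declared on the pipeline classes:

```python
def validate_shape(height, width, num_frames,
                   resolution_multiple_of=(16, 16),
                   is_valid_frame_count=lambda f: (f - 1) % 4 == 0):
    # Standalone mirror of the validate_shape pattern above, with Wan-style
    # constraints as defaults; the specific values here are assumptions.
    h_mul, w_mul = resolution_multiple_of
    if height % h_mul != 0 or width % w_mul != 0:
        raise ValueError(
            f"Resolution ({height}x{width}) must be multiples of ({h_mul}x{w_mul})."
        )
    if not is_valid_frame_count(num_frames):
        raise ValueError(f"Invalid num_frames={num_frames}.")

validate_shape(480, 832, 33)  # passes: multiples of 16, and 33 == 4*8 + 1
try:
    validate_shape(481, 832, 33)  # 481 is not a multiple of 16
except ValueError as e:
    print(e)
```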

⚠️ Potential issue | 🟠 Major

Reject non-positive dimensions before the model-specific checks.

The new fail-fast path still accepts impossible shapes like height <= 0, width <= 0, and some negative frame counts. For example, WAN's (num_frames - 1) % 4 == 0 predicate accepts -3, so those requests now get past validation and fail much later in latent allocation.

Proposed fix
     def validate_shape(self, height: int, width: int, num_frames: int) -> None:
         """Validate shape against model constraints. Raises ValueError."""
+        if height <= 0 or width <= 0:
+            raise ValueError(f"Resolution must be positive, got ({height}x{width}).")
+        if num_frames <= 0:
+            raise ValueError(f"num_frames must be positive, got {num_frames}.")
+
         h_mul, w_mul = self.resolution_multiple_of
         if h_mul > 1 or w_mul > 1:
             if height % h_mul != 0 or width % w_mul != 0:
                 raise ValueError(
                     f"Resolution ({height}x{width}) must be multiples of "
🧰 Tools
🪛 Ruff (0.15.5)

[warning] 141-144: Avoid specifying long messages outside the exception class

(TRY003)


[warning] 146-146: Avoid specifying long messages outside the exception class

(TRY003)


Comment on lines +397 to +402
for height, width, num_frames in shapes:
logger.info(f"Warmup: {height}x{width}, {num_frames} frames, {steps} steps")
self._run_warmup(height, width, num_frames, steps)
torch.cuda.synchronize()

torch.cuda.synchronize()
self._warmed_up_shapes = set(tuple(s) for s in shapes)

⚠️ Potential issue | 🔴 Critical

Guard torch.cuda.synchronize() for CPU warmup paths.

This runs unconditionally now, but the new warmup tests are CPU-only. On a CPU-only build or host, torch.cuda.synchronize() raises even though the warmup loop itself is otherwise backend-agnostic.

Proposed fix
         for height, width, num_frames in shapes:
             logger.info(f"Warmup: {height}x{width}, {num_frames} frames, {steps} steps")
             self._run_warmup(height, width, num_frames, steps)
-            torch.cuda.synchronize()
+            if torch.cuda.is_available():
+                torch.cuda.synchronize()

…ive check

- Move validate_shape() inside executor try/except so invalid shape
  errors are returned as DiffusionResponse instead of crashing worker
- Add non-positive dimension check in validate_shape() (height/width/
  num_frames must be > 0)
- Revert Flux resolution_multiple_of (out of scope for this PR)

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
@luyiyun1021 luyiyun1021 requested a review from zhenhuaw-me March 11, 2026 09:12
@luyiyun1021
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #38575 [ run ] triggered by Bot. Commit: c2f9050 Link to invocation


@zhenhuaw-me zhenhuaw-me left a comment


@NVShreyas for viz.

enable_cuda_graph: bool = False


class WarmupConfig(StrictBaseModel):

Maybe CompilationConfig?


And we can add a comment to the CUDA graph and torch compile configs mentioning that the shapes are configured here.

"""Check if frame count is valid for this model. Subclass override."""
return True

def validate_shape(self, height: int, width: int, num_frames: int) -> None:

Warmup is not a hard blocker for running inference. We may log a warning and ignore the invalid ones.

@zhenhuaw-me zhenhuaw-me requested a review from NVShreyas March 11, 2026 09:38
@tensorrt-cicd
Collaborator

PR_Github #38575 [ run ] completed with state SUCCESS. Commit: c2f9050
/LLM/main/L0_MergeRequest_PR pipeline #29914 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Comment on lines +226 to +228
More warmup shapes = slower startup, but lower risk of torch.compile
recompilation delays on first requests. Fewer shapes = faster startup,
but first request with an un-warmed shape triggers recompilation.

We should be careful about allowing users to add a lot of warmup shapes, as the torch compile cache is limited:

torch._dynamo.config.cache_size_limit = 128

Might be good to check how many shapes can fit before it fills up the cache
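A rough pre-flight check along these lines could look like the sketch below. The value 128 is the limit quoted in the comment above, not a verified torch default; in practice the limit would be read from torch._dynamo.config.cache_size_limit:

```python
# Sketch of the suggested guard: fail (or warn) when the Cartesian product of
# warmup shapes approaches the torch.compile recompile cache limit. The value
# 128 comes from the review comment, not from inspecting torch defaults.
DYNAMO_CACHE_SIZE_LIMIT = 128

def check_warmup_shape_budget(resolutions, num_frames, limit=DYNAMO_CACHE_SIZE_LIMIT):
    # Each (height, width, num_frames) combination can consume a cache entry.
    count = len(resolutions) * len(num_frames)
    if count > limit:
        raise ValueError(
            f"{count} warmup shapes exceed the torch.compile cache limit ({limit})."
        )
    return count

n = check_warmup_shape_budget([(480, 832), (720, 1280)], [33, 81])
```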

