[TRTLLM-11288][feat] Configurable warmup shapes for VisualGen #12107
luyiyun1021 wants to merge 2 commits into NVIDIA:main
Conversation
Add a WarmupConfig sub-config to allow users to specify which (height, width, num_frames) shapes to warm up at startup via YAML. Warmup shapes are the Cartesian product of resolutions x num_frames.

- New WarmupConfig with resolutions and num_frames fields
- BasePipeline.resolve_warmup_plan() with user > model-default priority
- BasePipeline.validate_shape() with model constraint declarations
- Migrate all 4 pipelines (Wan T2V/I2V, Flux, Flux2) to the new interface
- Request-level validation in the executor (fail-fast + warmup warning)
- Replace silent shape correction in I2V forward() with an explicit error
- 28 new unit tests in test_warmup.py

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
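The "user > model-default" priority and the Cartesian-product expansion described above can be sketched roughly as follows. The function name mirrors the PR's `resolve_warmup_plan()`, but the body is an illustrative assumption, not the actual implementation:

```python
from itertools import product

def resolve_warmup_plan(user_resolutions, user_num_frames,
                        default_resolutions, default_num_frames):
    """Sketch: each axis falls back to the model default only when the user
    did not set it; the final warmup shapes are the Cartesian product
    resolutions x num_frames."""
    resolutions = user_resolutions if user_resolutions is not None else default_resolutions
    num_frames = user_num_frames if user_num_frames is not None else default_num_frames
    return [(h, w, f) for (h, w), f in product(resolutions, num_frames)]

# User overrides resolutions; num_frames falls back to the model default.
plan = resolve_warmup_plan([(480, 832)], None, [(720, 1280)], [33, 81])
print(plan)  # [(480, 832, 33), (480, 832, 81)]
```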
📝 Walkthrough

This PR refactors the warmup configuration system across visual generation models from a hard-coded approach to a modular, configurable design. It introduces a new WarmupConfig model, replaces common_warmup_shapes with configurable methods (default_warmup_resolutions, default_warmup_num_frames, is_valid_frame_count), updates _run_warmup signatures to accept explicit parameters, and adds shape validation and warmup plan resolution logic.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 6
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/executor.py`:
- Around line 174-187: Move the call to validate_shape into the existing request
error-handling path so shape validation exceptions are caught and converted into
a DiffusionResponse error instead of bubbling out of process_request;
specifically, in process_request, call self.pipeline.validate_shape(req.height,
req.width, req.num_frames) inside the try block (the same block that constructs
and returns DiffusionResponse on error) rather than before it, ensuring any
raised errors are handled by the existing except/error response logic and
preserving the warmed-up-shapes warning logic around the now-validated request.
In `@tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py`:
- Around line 101-110: The pipeline currently only enforces frame-count via
is_valid_frame_count but must also enforce the FLUX.1 resolution multiple:
update validation (e.g., add or extend a method used by validate_shape) to check
that requested height and width are multiples of (self.vae_scale_factor * 2) and
fail validation otherwise; also update default_warmup_resolutions if needed to
remain valid under that rule. Reference the existing symbols
default_warmup_resolutions, default_warmup_num_frames, is_valid_frame_count,
validate_shape, and self.vae_scale_factor when implementing the check so
non-multiple resolutions are rejected early instead of being silently quantized.
In `@tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py`:
- Around line 169-178: The pipeline currently only declares frame-count
constraints (default_warmup_num_frames / is_valid_frame_count) but not spatial
constraints, so validate_shape() allows inputs that will be floored by
_prepare_latents()/ _prepare_latent_ids; add a spatial-validation hook on the
class (e.g., implement is_valid_spatial_resolution(self, height, width) or
override validate_shape to check spatial constraints) that returns True only for
resolutions that align with the packed latent grid (match entries in
default_warmup_resolutions or are exact multiples of the latent packing stride
used by _prepare_latents/_prepare_latent_ids), so requests like 1025x1024 are
rejected rather than silently floored. Ensure the new method uses the same
latent packing stride/scale as _prepare_latents/_prepare_latent_ids and
integrate it with existing validate_shape() logic.
In `@tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py`:
- Line 456: The code currently calls self.validate_shape(height, width,
num_frames) but does not reject requests where last_image is True and num_frames
< 2; add an explicit guard before preparing latents so that if self.last_image
(or the caller-provided last_image flag) is set and num_frames < 2 you raise a
ValueError with a clear message; adjust the logic either in validate_shape or
immediately around the call site that leads into _prepare_latents() so that
_prepare_latents() never receives num_frames < 2 (it allocates num_frames - 2
intermediate frames).
In `@tensorrt_llm/_torch/visual_gen/pipeline.py`:
- Around line 397-402: The unconditional call to torch.cuda.synchronize() after
running _run_warmup causes errors on CPU-only builds; guard the call with a CUDA
availability check (e.g., torch.cuda.is_available() or torch.cuda.device_count()
> 0) so synchronization only runs when CUDA is present, leaving the warmup loop
and setting of _warmed_up_shapes unchanged.
- Around line 136-146: In validate_shape, add an explicit fast-fail that rejects
non-positive inputs before the model-specific checks: verify height > 0, width >
0, and num_frames > 0 and raise a ValueError (including the class name via
self.__class__.__name__) if any are non-positive; keep this check above the
resolution_multiple_of logic and the is_valid_frame_count call so negative
dimensions or frame counts cannot pass model-specific predicates.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: acd47242-f370-49c3-ab58-92769cdde86c
📒 Files selected for processing (9)
- tensorrt_llm/_torch/visual_gen/config.py
- tensorrt_llm/_torch/visual_gen/executor.py
- tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
- tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
- tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
- tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
- tensorrt_llm/_torch/visual_gen/pipeline.py
- tests/unittest/_torch/visual_gen/test_visual_gen_args.py
- tests/unittest/_torch/visual_gen/test_warmup.py
```diff
 def process_request(self, req: DiffusionRequest):
     """Process a single request."""
+    self.pipeline.validate_shape(req.height, req.width, req.num_frames)
+
     if (
-        self.pipeline.common_warmup_shapes
-        and (req.height, req.width, req.num_frames) not in self.pipeline.common_warmup_shapes
+        self.pipeline._warmed_up_shapes
+        and (req.height, req.width, req.num_frames) not in self.pipeline._warmed_up_shapes
     ):
         logger.warning(
-            f"Requested shape (height={req.height}, width={req.width}, num_frames={req.num_frames}) "
-            f"was not warmed up. First request with this shape will be slower due to "
-            "torch.compile recompilation or CUDA graph capture."
-            f"Warmed-up shapes: {self.pipeline.common_warmup_shapes}"
+            f"Requested shape ({req.height}x{req.width}, {req.num_frames} frames) "
+            f"was not warmed up. First request will be slower due to "
+            f"torch.compile recompilation. "
+            f"Warmed-up shapes: {self.pipeline._warmed_up_shapes}"
         )
```
Keep the new shape validation inside the existing error path.
validate_shape() now runs before the try, so a bad request raises out of process_request(), skips DiffusionResponse(error_msg=...), and can take down the worker instead of failing just that request.
Proposed fix

```diff
 def process_request(self, req: DiffusionRequest):
     """Process a single request."""
-    self.pipeline.validate_shape(req.height, req.width, req.num_frames)
-
-    if (
-        self.pipeline._warmed_up_shapes
-        and (req.height, req.width, req.num_frames) not in self.pipeline._warmed_up_shapes
-    ):
-        logger.warning(
-            f"Requested shape ({req.height}x{req.width}, {req.num_frames} frames) "
-            f"was not warmed up. First request will be slower due to "
-            f"torch.compile recompilation. "
-            f"Warmed-up shapes: {self.pipeline._warmed_up_shapes}"
-        )
     try:
+        self.pipeline.validate_shape(req.height, req.width, req.num_frames)
+
+        if (
+            self.pipeline._warmed_up_shapes
+            and (req.height, req.width, req.num_frames) not in self.pipeline._warmed_up_shapes
+        ):
+            logger.warning(
+                f"Requested shape ({req.height}x{req.width}, {req.num_frames} frames) "
+                f"was not warmed up. First request will be slower due to "
+                f"torch.compile recompilation. "
+                f"Warmed-up shapes: {self.pipeline._warmed_up_shapes}"
+            )
+
         output = self.pipeline.infer(req)
         if self.rank == 0:
             self.response_queue.put(DiffusionResponse(request_id=req.request_id, output=output))
     except Exception as e:
```
```diff
-    @property
-    def common_warmup_shapes(self) -> list:
-        """Return list of common warmup shapes (height, width, num_frames)."""
-        return [(1024, 1024, 1)]
+    @property
+    def default_warmup_resolutions(self):
+        return [(1024, 1024)]
+
+    @property
+    def default_warmup_num_frames(self):
+        return [1]
+
+    def is_valid_frame_count(self, num_frames):
+        return num_frames == 1
```
Add the FLUX.1 resolution multiple override.
This pipeline now participates in validate_shape(), but only the frame-count constraint is declared. The latent pack/unpack path still quantizes height and width to self.vae_scale_factor * 2, so non-multiple resolutions can pass validation and produce a smaller image than requested instead of failing fast.
Proposed fix

```diff
     @property
     def default_warmup_num_frames(self):
         return [1]

+    @property
+    def resolution_multiple_of(self):
+        multiple = self.vae_scale_factor * 2
+        return (multiple, multiple)
+
     def is_valid_frame_count(self, num_frames):
         return num_frames == 1
```
```diff
-    @property
-    def common_warmup_shapes(self) -> list:
-        """Return list of common warmup shapes (height, width, num_frames)."""
-        return [(1024, 1024, 1)]
+    @property
+    def default_warmup_resolutions(self):
+        return [(1024, 1024)]
+
+    @property
+    def default_warmup_num_frames(self):
+        return [1]
+
+    def is_valid_frame_count(self, num_frames):
+        return num_frames == 1
```
Declare FLUX.2's spatial constraint in the new validation hooks.
validate_shape() now relies on subclass-declared constraints, but this pipeline only overrides frame count. _prepare_latents() / _prepare_latent_ids() still floor height and width to the packed latent grid, so a request like 1025x1024 will pass validation and effectively run at 1024x1024 instead of failing fast.
Proposed fix

```diff
     @property
     def default_warmup_num_frames(self):
         return [1]

+    @property
+    def resolution_multiple_of(self):
+        multiple = self.vae_scale_factor * 2
+        return (multiple, multiple)
+
     def is_valid_frame_count(self, num_frames):
         return num_frames == 1
```
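To make the silent-flooring failure mode concrete, here is a small sketch. The packing stride of 16 (i.e. an assumed `vae_scale_factor * 2`) is hypothetical; the actual stride comes from `_prepare_latents` / `_prepare_latent_ids`:

```python
def floor_to_latent_grid(height: int, width: int, stride: int = 16) -> tuple[int, int]:
    # Mimics the flooring that _prepare_latents() is described as doing;
    # the stride value here is an assumption for illustration.
    return (height // stride) * stride, (width // stride) * stride

# A 1025x1024 request silently becomes 1024x1024 instead of failing fast.
print(floor_to_latent_grid(1025, 1024))  # (1024, 1024)
```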
```diff
         f"Adjusting ({height}, {width}) -> ({calc_height}, {calc_width})."
     )
     height, width = calc_height, calc_width
+    self.validate_shape(height, width, num_frames)
```
Reject last_image requests with fewer than 2 frames.
This only validates the generic shape constraints. When last_image is set, _prepare_latents() later allocates num_frames - 2 intermediate frames, so num_frames == 1 still gets through here and then fails with a negative-dimension error instead of the intended fail-fast ValueError.
💡 Proposed fix

```diff
 self.validate_shape(height, width, num_frames)
+if last_image is not None and num_frames < 2:
+    raise ValueError(
+        "Wan I2V interpolation requires num_frames >= 2 when last_image is provided."
+    )
```
```python
def validate_shape(self, height: int, width: int, num_frames: int) -> None:
    """Validate shape against model constraints. Raises ValueError."""
    h_mul, w_mul = self.resolution_multiple_of
    if h_mul > 1 or w_mul > 1:
        if height % h_mul != 0 or width % w_mul != 0:
            raise ValueError(
                f"Resolution ({height}x{width}) must be multiples of "
                f"({h_mul}x{w_mul}) for {self.__class__.__name__}."
            )
    if not self.is_valid_frame_count(num_frames):
        raise ValueError(f"Invalid num_frames={num_frames} for {self.__class__.__name__}.")
```
Reject non-positive dimensions before the model-specific checks.
The new fail-fast path still accepts impossible shapes like height <= 0, width <= 0, and some negative frame counts. For example, WAN's (num_frames - 1) % 4 == 0 predicate accepts -3, so those requests now get past validation and fail much later in latent allocation.
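The negative-frame-count case can be checked directly. The predicate below paraphrases the WAN frame-count rule quoted above; Python's `%` follows the divisor's sign, so negative inputs can satisfy it:

```python
# WAN-style frame-count predicate, paraphrased from the review comment above.
def wan_frame_count_ok(num_frames: int) -> bool:
    return (num_frames - 1) % 4 == 0

print(wan_frame_count_ok(81))  # True  (a valid WAN frame count)
print(wan_frame_count_ok(-3))  # True  (nonsense value passes the predicate)
```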
Proposed fix

```diff
 def validate_shape(self, height: int, width: int, num_frames: int) -> None:
     """Validate shape against model constraints. Raises ValueError."""
+    if height <= 0 or width <= 0:
+        raise ValueError(f"Resolution must be positive, got ({height}x{width}).")
+    if num_frames <= 0:
+        raise ValueError(f"num_frames must be positive, got {num_frames}.")
+
     h_mul, w_mul = self.resolution_multiple_of
     if h_mul > 1 or w_mul > 1:
         if height % h_mul != 0 or width % w_mul != 0:
             raise ValueError(
                 f"Resolution ({height}x{width}) must be multiples of "
```

🧰 Tools
🪛 Ruff (0.15.5)
- [warning] 141-144: Avoid specifying long messages outside the exception class (TRY003)
- [warning] 146-146: Avoid specifying long messages outside the exception class (TRY003)
```python
for height, width, num_frames in shapes:
    logger.info(f"Warmup: {height}x{width}, {num_frames} frames, {steps} steps")
    self._run_warmup(height, width, num_frames, steps)
    torch.cuda.synchronize()

self._warmed_up_shapes = set(tuple(s) for s in shapes)
```
Guard torch.cuda.synchronize() for CPU warmup paths.
This runs unconditionally now, but the new warmup tests are CPU-only. On a CPU-only build or host, torch.cuda.synchronize() raises even though the warmup loop itself is otherwise backend-agnostic.
Proposed fix

```diff
 for height, width, num_frames in shapes:
     logger.info(f"Warmup: {height}x{width}, {num_frames} frames, {steps} steps")
     self._run_warmup(height, width, num_frames, steps)
-    torch.cuda.synchronize()
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
```
…ive check

- Move validate_shape() inside executor try/except so invalid shape errors are returned as DiffusionResponse instead of crashing the worker
- Add non-positive dimension check in validate_shape() (height/width/num_frames must be > 0)
- Revert Flux resolution_multiple_of (out of scope for this PR)

Signed-off-by: Yiyun Lu <55233584+luyiyun1021@users.noreply.github.com>
/bot run --disable-fail-fast

PR_Github #38575 [ run ] triggered by Bot. Commit:
```python
enable_cuda_graph: bool = False


class WarmupConfig(StrictBaseModel):
```
We could also add a comment to the CUDA graph and torch.compile configs mentioning that the shapes are configured here.
```python
    """Check if frame count is valid for this model. Subclass override."""
    return True

def validate_shape(self, height: int, width: int, num_frames: int) -> None:
```
Warmup is not a hard blocker for running inference. We could log a warning and ignore the invalid shapes instead.
PR_Github #38575 [ run ] completed with state
```
More warmup shapes = slower startup, but lower risk of torch.compile
recompilation delays on first requests. Fewer shapes = faster startup,
but first request with an un-warmed shape triggers recompilation.
```
We should be careful about allowing users to add many warmup shapes, since the torch.compile cache is limited. It might be good to check how many shapes can fit before the cache fills up.
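For a rough sense of scale: torch.compile keeps a bounded number of compiled variants per function, controlled by `torch._dynamo.config.cache_size_limit` (default 8 in recent PyTorch releases). A minimal sketch of such a check, with the limit passed in as a plain argument so the snippet stays torch-free:

```python
def warmup_plan_fits_compile_cache(num_shapes: int, cache_size_limit: int = 8) -> bool:
    """Rough guard: warming up more shapes than torch._dynamo's per-function
    cache_size_limit risks evicting earlier compilations, defeating warmup."""
    return num_shapes <= cache_size_limit

print(warmup_plan_fits_compile_cache(3))   # True
print(warmup_plan_fits_compile_cache(12))  # False
```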
@coderabbitai summary
Description

Make VisualGen warmup shapes configurable via a new `WarmupConfig` sub-config, replacing hardcoded warmup shapes in model pipeline subclasses.

Problem: Warmup shapes (resolution + frame count) were hardcoded in each pipeline class (e.g., `[(480, 832, 33), (480, 832, 81), (720, 1280, 81)]` for Wan). Users could not configure which shapes to pre-compile at startup, and requesting an un-warmed shape triggers torch.compile recompilation with seconds of delay.

Solution:

- `WarmupConfig` sub-config with `resolutions` and `num_frames` fields, combined via Cartesian product at warmup time
- `BasePipeline.resolve_warmup_plan()` to resolve user config vs. model defaults (user > model-default > empty)
- `BasePipeline.validate_shape()` with model constraint declarations (`resolution_multiple_of`, `is_valid_frame_count`)
- Request-level validation in `DiffusionExecutor.process_request()`: fail-fast for invalid shapes, warning for un-warmed shapes
- Replace silent shape correction in I2V `forward()` with an explicit `validate_shape()` error

User-facing YAML config:
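A sketch of what the user-facing YAML might look like; the key names `warmup`, `resolutions`, and `num_frames` are inferred from the `WarmupConfig` fields, and the exact schema may differ:

```yaml
# Warmup shapes = Cartesian product of resolutions x num_frames
warmup:
  resolutions:
    - [480, 832]
    - [720, 1280]
  num_frames: [33, 81]
```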
Design decisions:

- `WarmupConfig` as a sub-config follows the same pattern as the existing 6 sub-configs (`torch_compile`, `cuda_graph`, etc.)
- `resolutions x num_frames` Cartesian product is simpler to configure than explicit (h, w, f) triples
- `warmup_steps` is not exposed; it stays a model-internal heuristic (`default_warmup_steps`)

Backward compatibility:

- Model defaults apply when `warmup` is not configured
- `forward()` now raises `ValueError` on invalid shapes instead of silently rounding (intentional)

Test Coverage

- `tests/unittest/_torch/visual_gen/test_warmup.py`: 28 new unit tests (CPU only)
  - `TestWarmupConfig` (6): config construction, defaults, empty lists, strict validation
  - `TestVisualGenArgsWarmup` (4): dict/YAML construction, pickle serialization
  - `TestResolveWarmupPlan` (10): Cartesian product, priority chain, empty lists, invalid values, Flux constraints
  - `TestValidateShape` (7): resolution/frame validation, edge cases, Flux single-frame constraint
  - `TestWarmupExecution` (3): `_warmed_up_shapes` tracking, empty-shapes skip
  - `TestRequestValidation` (3): warmed/un-warmed/invalid shape behavior
- `tests/unittest/_torch/visual_gen/test_visual_gen_args.py`: 1 new test for `WarmupConfig` strict validation

Existing CI tests are not affected: all use `skip_warmup=True` and do not reference removed APIs.

PR Checklist
- [ ] PR description clearly explains what and why.
- [ ] PR follows TRT-LLM coding guidelines.
- [ ] Test cases are provided for new code paths.
- [ ] Any new dependencies have been scanned for license and vulnerabilities.
- [ ] CODEOWNERS updated if ownership changes.
- [ ] Documentation updated as needed.
- [ ] Update tava architecture diagram if significant design change.
- [ ] Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help

To see a list of available CI bot commands, please comment `/bot help`.