[TRTLLM-11362][feat] Add batch generation support to visual gen pipelines#12121
karljang wants to merge 1 commit into NVIDIA:main from
Conversation
📝 Walkthrough

This PR adds batch prompt generation support across visual generation pipelines (Flux, Flux2, WAN, WAN I2V) and related APIs by widening prompt parameter types to `Union[str, List[str]]`.

Changes
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client
    participant VisualGen
    participant Pipeline
    participant LatentProcessor
    participant Decoder
    participant Output
    Client->>VisualGen: generate_async(prompts: Union[str, List[str]])
    VisualGen->>VisualGen: Infer batch_size from prompt type
    VisualGen->>Pipeline: forward(prompt, batch_size)
    Pipeline->>Pipeline: batch_size = len(prompt) if list else 1
    Pipeline->>Pipeline: Encode prompt(s) to embeddings
    Pipeline->>LatentProcessor: _prepare_latents(batch_size, height, width)
    LatentProcessor->>LatentProcessor: Create batch-aware latent shape (B, C, H', W')
    LatentProcessor->>Pipeline: Return latents
    Pipeline->>Pipeline: Model processing with batch dimension
    Pipeline->>Decoder: _decode_latents(latents, batch_size)
    Decoder->>Decoder: Process per-batch outputs
    alt batch_size == 1
        Decoder->>Decoder: Return single image (H, W, C)
    else batch_size > 1
        Decoder->>Decoder: Return batch tensor (B, H, W, C)
    end
    Decoder->>Pipeline: Return images
    Pipeline->>Output: Return batch or single output
    Output->>Client: Deliver result
```
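The batch-size inference and output-shape collapse shown in the diagram can be sketched in isolation. The helper names below are hypothetical (they are not the pipeline's actual API); they only illustrate the `len(prompt) if list else 1` rule and the backward-compatible single-sample return.

```python
from typing import List, Union

def infer_batch_size(prompt: Union[str, List[str]]) -> int:
    # A single string keeps the legacy single-sample path;
    # a list of prompts sets B = len(prompt).
    return len(prompt) if isinstance(prompt, list) else 1

def collapse_batch(outputs: list, batch_size: int):
    # batch_size == 1 returns the bare (H, W, C) sample for backward
    # compatibility; batch_size > 1 keeps the leading batch dimension.
    return outputs[0] if batch_size == 1 else outputs

assert infer_batch_size("a cat on a mat") == 1
assert infer_batch_size(["a cat", "a dog"]) == 2
```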
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 2
🧹 Nitpick comments (4)
tests/unittest/_torch/visual_gen/test_wan_i2v.py (1)
1087-1100: This fixture only validates one Wan I2V variant at a time.
`i2v_full_pipeline` loads whichever single checkpoint `CHECKPOINT_PATH` points to, and the module default is Wan2.2. That means the Wan 2.1-only batch branch in `tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py` (the new `image_embeds.repeat(batch_size, 1, 1)` path) is untested unless CI swaps checkpoints. Please split or parameterize the fixture by variant.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/test_wan_i2v.py` around lines 1087 - 1100: The fixture i2v_full_pipeline currently loads a single checkpoint via CHECKPOINT_PATH so only the module default (Wan2.2) is exercised; update the fixture to cover both Wan variants by parameterizing or splitting it: make i2v_full_pipeline a parametrized pytest fixture (or add a second fixture) that iterates over two checkpoint choices or variant labels (e.g., "wan2.2" and "wan2.1") and for each instantiates VisualGenArgs with the corresponding checkpoint/variant and then calls PipelineLoader(args).load(skip_warmup=True) so the Wan2.1 path in tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py (the image_embeds.repeat(batch_size, 1, 1) branch) is exercised in tests; ensure the fixture skips when a given checkpoint is missing by checking existence of each checkpoint path like the current CHECKPOINT_PATH check.

tests/unittest/_torch/visual_gen/test_visual_gen_args.py (1)
273-289: Add a conflicting-`negative_prompt` batch case here.

This only exercises the happy path where one item provides `negative_prompt`, so it won't catch the current silent-drop behavior when a later batch item supplies a different value. Once `generate_async()` is fixed, please pin that validation here as well.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/test_visual_gen_args.py` around lines 273 - 289: Extend the test to cover a conflicting-negative_prompt case: in test_list_of_dicts_input (or a new test) call vg.generate_async with inputs where the first dict has "negative_prompt": "dark" and a later dict has a different "negative_prompt" (e.g., "light"), then assert that the call fails validation by wrapping the call in pytest.raises(ValueError) (or the project's ValidationError) and/or asserting that vg.executor.enqueue_requests was not invoked; reference generate_async and vg.executor.enqueue_requests to locate the code under test.

tests/unittest/_torch/visual_gen/test_wan.py (2)
3473-3487: Prefix unused `T` variable with underscore.

Static analysis flags `T` as unused. Use `_T` to indicate the value is intentionally ignored while still documenting the shape.

Proposed fix:

```diff
 assert result.video.dim() == 4, f"Expected 4D (T,H,W,C), got {result.video.dim()}D"
-T, H, W, C = result.video.shape
+_T, H, W, C = result.video.shape
 assert H == 480 and W == 832 and C == 3
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/test_wan.py` around lines 3473 - 3487, The test test_single_prompt_backward_compat in tests/unittest/_torch/visual_gen/test_wan.py defines T, H, W, C = result.video.shape but never uses T; rename the unused T to _T to satisfy static analysis and indicate intentional ignoring of the temporal dimension while keeping the shape unpacking (change the tuple target from T to _T in the test_single_prompt_backward_compat function).
3489-3504: Prefix unused `T` variable with underscore.

Same issue as above: `T` is unpacked but never used. Use `_T` to silence the warning.

Proposed fix:

```diff
 assert result.video.dim() == 5, f"Expected 5D (B,T,H,W,C), got {result.video.dim()}D"
-B, T, H, W, C = result.video.shape
+B, _T, H, W, C = result.video.shape
 assert B == 2 and H == 480 and W == 832 and C == 3
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/test_wan.py` around lines 3489 - 3504, In test_batch_prompt_shape rename the unused unpacked variable T to _T to silence the unused-variable warning: update the tuple unpacking in the test_batch_prompt_shape function (B, T, H, W, C = result.video.shape) to use _T instead of T and leave the rest of the assertions unchanged so B, H, W, C assertions still operate on the correct values.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py`:
- Around line 507-510: The pre-repeat of image_embeds causes a shape mismatch
when CFG doubles latents in denoise(); inside forward_fn (the function that
builds latent_model_input / calls current_model) detect when
encoder_hidden_states_image (image_embeds) batch size differs from
latents_input/latent_model_input and repeat the image embeddings to match by
computing repeat_factor = latents_input.shape[0] // image_embeds.shape[0] (only
when divisible) and using image_embeds_to_use =
image_embeds.repeat(repeat_factor, 1, 1) before passing
encoder_hidden_states_image=image_embeds_to_use to current_model; this keeps
condition_data expansion logic intact and prevents relying on accidental
broadcasting in denoise()/forward_fn.
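The fix the prompt describes can be sketched as a small helper. The function name `match_image_embeds` is hypothetical; the shapes follow the prompt's scenario where CFG doubles the latent batch while the image embeddings keep the original batch size.

```python
import torch

def match_image_embeds(image_embeds: torch.Tensor,
                       latent_model_input: torch.Tensor) -> torch.Tensor:
    """Repeat image embeddings so their batch matches the latent batch.

    Handles the CFG case where latents are doubled to (2B, ...) while
    image_embeds stays at (B, S, D); repeats only when divisible, mirroring
    image_embeds.repeat(repeat_factor, 1, 1) from the review prompt.
    """
    lat_b, emb_b = latent_model_input.shape[0], image_embeds.shape[0]
    if lat_b == emb_b:
        return image_embeds
    if lat_b % emb_b != 0:
        raise ValueError(
            f"latent batch {lat_b} is not a multiple of embed batch {emb_b}")
    return image_embeds.repeat(lat_b // emb_b, 1, 1)

# CFG doubles a batch of 2 prompts to 4 latent samples.
embeds = torch.randn(2, 257, 1280)
latents = torch.randn(4, 16, 60, 104)
assert match_image_embeds(embeds, latents).shape == (4, 257, 1280)
```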
In `@tensorrt_llm/llmapi/visual_gen.py`:
- Around line 534-545: The batch-handling branch that builds prompt and
negative_prompt (variables inputs, prompt, negative_prompt) must validate each
item before collapsing: ensure every item is either a str or dict with a present
"prompt" key and reject/raise a clear API error on invalid types or missing
"prompt"; additionally detect per-item conflicting "negative_prompt" values and
either (a) reject the batch with a validation error if you want a single
negative_prompt per-request, or (b) propagate per-item negatives by extending
DiffusionRequest.negative_prompt to accept a list and attach corresponding
negatives instead of using only the first one; implement the chosen behavior in
the function that builds the DiffusionRequest so callers get deterministic
errors or per-sample negatives rather than silent drops.
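A minimal sketch of option (a) from the prompt, validating each batch item before collapsing. The helper name `parse_batch_inputs` and the error wording are illustrative, not the actual `generate_async()` API.

```python
from typing import List, Optional, Tuple, Union

def parse_batch_inputs(inputs: list) -> Tuple[List[str], Optional[str]]:
    """Validate batch items and collapse them into (prompts, negative_prompt).

    Rejects non-str/dict items, dicts missing "prompt", and conflicting
    per-item negative_prompt values instead of silently dropping them.
    """
    prompts: List[str] = []
    negative: Optional[str] = None
    for i, item in enumerate(inputs):
        if isinstance(item, str):
            prompts.append(item)
        elif isinstance(item, dict):
            if "prompt" not in item:
                raise ValueError(f"batch item {i} is missing required key 'prompt'")
            prompts.append(item["prompt"])
            neg = item.get("negative_prompt")
            if neg is not None:
                if negative is not None and neg != negative:
                    raise ValueError(
                        f"batch item {i} has conflicting negative_prompt {neg!r}")
                negative = neg
        else:
            raise TypeError(
                f"batch item {i} must be str or dict, got {type(item).__name__}")
    return prompts, negative
```

With this in place, the conflicting-`negative_prompt` test the nitpick asks for becomes a straightforward `pytest.raises(ValueError)` check.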
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: e3ca3680-b29e-48d5-9490-7092f6461670
📒 Files selected for processing (10)
- tensorrt_llm/_torch/visual_gen/executor.py
- tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
- tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
- tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
- tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
- tensorrt_llm/llmapi/visual_gen.py
- tests/unittest/_torch/visual_gen/test_flux_pipeline.py
- tests/unittest/_torch/visual_gen/test_visual_gen_args.py
- tests/unittest/_torch/visual_gen/test_wan.py
- tests/unittest/_torch/visual_gen/test_wan_i2v.py
Force-pushed from c91da00 to a003a3d
/bot help
GitHub Bot Help
Provide a user-friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.

Details

run

Launch build/test pipelines. All previously running jobs will be killed.

kill

Kill all running builds associated with the pull request.

skip

Skip testing for the latest commit on the pull request.

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
/bot run --disable-fail-fast
Force-pushed from a003a3d to b88c1de
PR_Github #38632 [ run ] triggered by Bot. Commit:
chang-l left a comment
I feel we may need some more work to properly enable batched inference for the Wan I2V task. Is that correct, @o-stoner @karljang ? For example, supporting multiple images within a single request, and batching requests with independent text/image inputs.
Maybe we can have another PR to enable batch generation with image input if I understand the current limitation correctly
```python
image_embeds = image_embeds.to(self.dtype)
# Repeat for batch: single image, multiple prompts
if batch_size > 1:
    image_embeds = image_embeds.repeat(batch_size, 1, 1)
```
my understanding is a batch of requests may also contain different images, right?
Good point. Actually, the current implementation is aligned with HF diffusers batch support, meaning the HF diffusers WAN I2V pipeline supports only one image with multiple prompts~
Oh.. diffusers does support multiple images. Let me check further.
Confirmed using CC :)
So HF diffusers WAN I2V does NOT support multiple images. The PipelineImageInput type alias allows list[PIL.Image], but the WAN I2V check_inputs() rejects it. It's single image only — the type is shared across all diffusers pipelines but each pipeline validates differently.
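The pattern described here (a permissive shared type alias, with each pipeline enforcing its own stricter `check_inputs()`) can be illustrated schematically. This is not diffusers' actual code; `str` stands in for `PIL.Image` to keep the sketch self-contained.

```python
from typing import List, Union

# Stand-in for the shared diffusers alias: permissive across all pipelines.
PipelineImageInput = Union[str, List[str]]

def check_inputs_i2v(image: PipelineImageInput) -> None:
    # The I2V pipeline validates more strictly than the alias allows:
    # a single image only, even though the type permits lists.
    if isinstance(image, list):
        raise ValueError("this pipeline accepts a single image, not a list of images")
```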
Thanks for checking @karljang, since this PR aligns with diffusers, I think probably it is fine to leave it as-is in this PR.
But more general, I think batching support should be extended to handle multiple independent requests in future, i.e., the image encoder would need to handle batched image inputs, although the throughput impact may be limited.
btw, could we have some logging/warnings to users regarding this single-image multiple prompts/requests limitation?
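One possible shape for the requested warning, emitted when a multi-prompt batch reuses a single image. The logger name and helper function are illustrative, not the project's actual logging setup.

```python
import logging

logger = logging.getLogger("tensorrt_llm.visual_gen")

def warn_single_image_batch(batch_size: int, num_images: int) -> None:
    # Warn only in the limited case: several prompts sharing one image.
    if batch_size > 1 and num_images == 1:
        logger.warning(
            "WAN I2V batching currently supports a single image shared across "
            "%d prompts (matching HF diffusers); independent per-prompt images "
            "are not yet supported.", batch_size)
```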
Force-pushed from b88c1de to 494685a
…ines

Add batch inference support to all visual generation pipelines (FLUX.1, FLUX.2, WAN T2V, WAN I2V). A single forward() call now accepts a list of prompts and generates all outputs in parallel with proper CFG handling.

- prompt parameter accepts Union[str, List[str]] across all pipelines
- Single prompt returns original shape for backward compatibility
- Seed behavior aligned with HF diffusers (single generator, batch_size in shape)
- API-level support in DiffusionRequest and VisualGeneration.generate_async()
- 16 new tests covering batch shape, backward compat, and API parsing

Signed-off-by: Kanghwan Jang <861393+karljang@users.noreply.github.com>
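The seed behavior noted in the commit message (one generator, batch dimension in the noise shape, rather than seeding each sample with seed + i) can be sketched as follows; the helper name and latent shape are illustrative.

```python
import torch

def make_latents(seed: int, batch_size: int,
                 sample_shape=(4, 8, 8)) -> torch.Tensor:
    # One generator seeded once; the batch dimension lives in the randn
    # shape, mirroring HF diffusers' approach, instead of looping over
    # samples with per-sample seeds.
    gen = torch.Generator().manual_seed(seed)
    return torch.randn((batch_size, *sample_shape), generator=gen)

assert make_latents(42, 1).shape == (1, 4, 8, 8)
assert make_latents(42, 3).shape == (3, 4, 8, 8)
```

Same seed and batch size reproduce the same latents, which is what makes batched runs deterministic end to end.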
Force-pushed from 494685a to 284b86f
PR_Github #38632 [ run ] completed with state
Summary by CodeRabbit
Release Notes
New Features
Tests
Summary
Add batch inference support to all visual generation pipelines (FLUX.1, FLUX.2, WAN T2V, WAN I2V). A single
`forward()` call can now accept a list of prompts and generate all outputs in parallel, with proper CFG (classifier-free guidance) handling for the full batch.

- `prompt` parameter accepts `Union[str, List[str]]` across all pipelines
- Single prompt returns `(H,W,C)` / `(T,H,W,C)` for backward compatibility
- Batch returns `(B,H,W,C)` / `(B,T,H,W,C)` with the batch dimension prepended
- Seed behavior aligned with HF diffusers: `batch_size` in tensor shape (not per-sample `seed+i`)
- API-level support in `DiffusionRequest` and `VisualGeneration.generate_async()`

Test Coverage
- `pytest tests/unittest/_torch/visual_gen/test_visual_gen_args.py`: API-level batch input parsing (no GPU)
- `pytest tests/unittest/_torch/visual_gen/test_flux_pipeline.py -k "batch"`: FLUX batch generation (1x GPU)
- `pytest tests/unittest/_torch/visual_gen/test_wan.py -k "batch"`: WAN T2V batch generation (1x GPU)
- `pytest tests/unittest/_torch/visual_gen/test_wan_i2v.py -k "batch"`: WAN I2V batch generation (1x GPU)

PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.