diff --git a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp
index fc161ab4a6ca..b71c39d40874 100644
--- a/cpp/tensorrt_llm/nanobind/thop/bindings.cpp
+++ b/cpp/tensorrt_llm/nanobind/thop/bindings.cpp
@@ -70,8 +70,8 @@ void initBindings(nb::module_& m)
         nb::arg("cu_kv_seqlens") = std::nullopt, nb::arg("fmha_scheduler_counter") = std::nullopt,
         nb::arg("mla_bmm1_scale") = std::nullopt, nb::arg("mla_bmm2_scale") = std::nullopt,
         nb::arg("quant_q_buffer") = std::nullopt, nb::arg("flash_mla_tile_scheduler_metadata") = std::nullopt,
-        nb::arg("flash_mla_num_splits") = std::nullopt, "Multi-head attention operation",
-        nb::call_guard<nb::gil_scoped_release>());
+        nb::arg("flash_mla_num_splits") = std::nullopt, nb::arg("num_contexts") = 0, nb::arg("num_ctx_tokens") = 0,
+        "Multi-head attention operation", nb::call_guard<nb::gil_scoped_release>());

     m.def(
         "get_helix_workspace_size_per_rank",
diff --git a/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp b/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp
index 940d59258ca8..f5a1336ea3e1 100644
--- a/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp
+++ b/cpp/tensorrt_llm/thop/IndexerKCacheScatterOp.cpp
@@ -28,69 +28,66 @@ TRTLLM_NAMESPACE_BEGIN
 namespace torch_ext
 {

-void indexer_k_cache_scatter_op(th::Tensor const& k_fp8_bytes, th::Tensor const& k_scale_bytes, th::Tensor& k_cache,
-    th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale)
+void indexer_k_cache_scatter_op(th::Tensor const& k_fp8, th::Tensor const& k_scale, th::Tensor& k_cache,
+    th::Tensor const& slot_mapping_fp8, th::Tensor const& slot_mapping_scale, int64_t num_tokens)
 {
-    // Validate all tensors are CUDA tensors
-    TORCH_CHECK(k_fp8_bytes.is_cuda() && k_scale_bytes.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
+    // k_fp8: [>=num_tokens, head_dim] in FP8 (1 byte/element) — reinterpreted as uint8
+    // k_scale: [>=num_tokens, head_dim // quant_block_size] in float32 — reinterpreted as uint8 bytes
+    // slot_mapping_fp8, slot_mapping_scale: [>=num_tokens] int64 — only first num_tokens used
+    // k_cache: [num_blocks, block_size, 1, per_token_size] uint8
+
+    TORCH_CHECK(k_fp8.is_cuda() && k_scale.is_cuda() && k_cache.is_cuda() && slot_mapping_fp8.is_cuda()
         && slot_mapping_scale.is_cuda(),
         "All tensors must be CUDA tensors");

     // Validate tensor dimensions
-    TORCH_CHECK(k_fp8_bytes.dim() == 2, "k_fp8_bytes must be a 2D Tensor [num_tokens, head_dim]");
-    TORCH_CHECK(k_scale_bytes.dim() == 2, "k_scale_bytes must be a 2D Tensor [num_tokens, scale_size]");
-    TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be a 1D Tensor [num_tokens]");
-    TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be a 1D Tensor [num_tokens]");
-
-    // Enforce k_cache is 4D tensor
-    TORCH_CHECK(k_cache.dim() == 4,
-        "k_cache must be a 4D Tensor [num_blocks, block_size, 1, per_token_size], got %d dimensions",
+    TORCH_CHECK(k_fp8.dim() == 2, "k_fp8 must be 2D [num_tokens, head_dim]");
+    TORCH_CHECK(k_scale.dim() == 2, "k_scale must be 2D [num_tokens, scale_elements]");
+    TORCH_CHECK(slot_mapping_fp8.dim() == 1, "slot_mapping_fp8 must be 1D [num_tokens]");
+    TORCH_CHECK(slot_mapping_scale.dim() == 1, "slot_mapping_scale must be 1D [num_tokens]");
+    TORCH_CHECK(k_cache.dim() == 4, "k_cache must be 4D [num_blocks, block_size, 1, per_token_size], got %d dims",
         static_cast<int>(k_cache.dim()));

-    // Validate tensor dtypes
-    TORCH_CHECK(k_fp8_bytes.scalar_type() == torch::kUInt8, "k_fp8_bytes must be uint8");
-    TORCH_CHECK(k_scale_bytes.scalar_type() == torch::kUInt8, "k_scale_bytes must be uint8");
+    // Validate tensor dtypes — reinterpret_cast below assumes specific element sizes
+    TORCH_CHECK(k_fp8.element_size() == 1, "k_fp8 must have 1-byte elements (e.g. FP8), got %d", k_fp8.element_size());
+    TORCH_CHECK(k_scale.element_size() == 4, "k_scale must have 4-byte elements (e.g. float32), got %d",
+        k_scale.element_size());
     TORCH_CHECK(slot_mapping_fp8.scalar_type() == torch::kInt64, "slot_mapping_fp8 must be int64");
     TORCH_CHECK(slot_mapping_scale.scalar_type() == torch::kInt64, "slot_mapping_scale must be int64");

-    // Validate tensor shapes are consistent
-    auto num_tokens = static_cast<int32_t>(k_fp8_bytes.size(0));
-    TORCH_CHECK(
-        k_scale_bytes.size(0) == num_tokens, "k_scale_bytes first dimension must equal k_fp8_bytes first dimension");
-    TORCH_CHECK(slot_mapping_fp8.size(0) == num_tokens, "slot_mapping_fp8 length must equal num_tokens");
-    TORCH_CHECK(slot_mapping_scale.size(0) == num_tokens, "slot_mapping_scale length must equal num_tokens");
-
-    // Validate tensors are contiguous (except k_cache which may be non-contiguous)
-    TORCH_CHECK(k_fp8_bytes.is_contiguous(), "k_fp8_bytes must be contiguous");
-    TORCH_CHECK(k_scale_bytes.is_contiguous(), "k_scale_bytes must be contiguous");
-    // k_cache can be non-contiguous - we handle this via strides
+    TORCH_CHECK(k_fp8.is_contiguous(), "k_fp8 must be contiguous");
+    TORCH_CHECK(k_scale.is_contiguous(), "k_scale must be contiguous");
     TORCH_CHECK(slot_mapping_fp8.is_contiguous(), "slot_mapping_fp8 must be contiguous");
     TORCH_CHECK(slot_mapping_scale.is_contiguous(), "slot_mapping_scale must be contiguous");

-    int32_t head_dim = static_cast<int32_t>(k_fp8_bytes.size(1)); // head_dim = quant_block_size = 128
-    int32_t scale_size = static_cast<int32_t>(k_scale_bytes.size(1)); // scale_size = 4 bytes
-
-    int32_t cache_dim_0 = static_cast<int32_t>(k_cache.size(0)); // num_blocks
-    int32_t cache_dim_1 = static_cast<int32_t>(k_cache.size(1)); // block_size
-    int32_t cache_dim_2 = static_cast<int32_t>(k_cache.size(2)); // num_kv_heads
-    int32_t cache_dim_3 = static_cast<int32_t>(k_cache.size(3)); // per_token_size
-
-    // Validation for indexer k cache pool for DeepSeek-V3.2 constraints
-    TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1 for DeepSeek-V3.2, got %d", cache_dim_2);
-    TORCH_CHECK(head_dim == 128, "k_fp8_bytes head_dim must be 128 for DeepSeek-V3.2, got %d", head_dim);
-    TORCH_CHECK(scale_size == 4, "k_scale_bytes scale_size must be 4 bytes for DeepSeek-V3.2, got %d", scale_size);
-
-    int64_t cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
-    int64_t cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
-    int64_t cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
-    int64_t cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));
-
-    auto stream = at::cuda::getCurrentCUDAStream(k_fp8_bytes.get_device());
-
-    tk::invokeIndexerKCacheScatter(k_fp8_bytes.data_ptr<uint8_t>(), k_scale_bytes.data_ptr<uint8_t>(),
-        k_cache.data_ptr<uint8_t>(), slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(),
-        num_tokens, head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0,
-        cache_stride_1, cache_stride_2, cache_stride_3, stream);
+    // FP8 is 1 byte per element, so head_dim in elements == head_dim in bytes.
+    int32_t const head_dim = static_cast<int32_t>(k_fp8.size(1));
+    // Scale size in bytes: num_scale_elements * bytes_per_element.
+    int32_t const scale_size = static_cast<int32_t>(k_scale.size(1)) * static_cast<int32_t>(k_scale.element_size());
+
+    int32_t const cache_dim_0 = static_cast<int32_t>(k_cache.size(0));
+    int32_t const cache_dim_1 = static_cast<int32_t>(k_cache.size(1));
+    int32_t const cache_dim_2 = static_cast<int32_t>(k_cache.size(2));
+    int32_t const cache_dim_3 = static_cast<int32_t>(k_cache.size(3));
+
+    TORCH_CHECK(cache_dim_2 == 1, "k_cache dimension 2 must be 1, got %d", cache_dim_2);
+    TORCH_CHECK(head_dim == 128, "k_fp8 head_dim must be 128, got %d", head_dim);
+    TORCH_CHECK(scale_size == 4, "k_scale scale_size must be 4 bytes, got %d", scale_size);
+
+    int64_t const cache_stride_0 = static_cast<int64_t>(k_cache.stride(0));
+    int64_t const cache_stride_1 = static_cast<int64_t>(k_cache.stride(1));
+    int64_t const cache_stride_2 = static_cast<int64_t>(k_cache.stride(2));
+    int64_t const cache_stride_3 = static_cast<int64_t>(k_cache.stride(3));
+
+    auto stream = at::cuda::getCurrentCUDAStream(k_fp8.get_device());
+
+    // Reinterpret k_fp8 as uint8 bytes and k_scale as raw bytes via data_ptr.
+    // For slot mappings, use data_ptr directly — only the first num_tokens entries are read.
+    tk::invokeIndexerKCacheScatter(reinterpret_cast<uint8_t const*>(k_fp8.data_ptr()),
+        reinterpret_cast<uint8_t const*>(k_scale.data_ptr()), k_cache.data_ptr<uint8_t>(),
+        slot_mapping_fp8.data_ptr<int64_t>(), slot_mapping_scale.data_ptr<int64_t>(), static_cast<int32_t>(num_tokens),
+        head_dim, scale_size, cache_dim_0, cache_dim_1, cache_dim_2, cache_dim_3, cache_stride_0, cache_stride_1,
+        cache_stride_2, cache_stride_3, stream);
 }

 } // namespace torch_ext
@@ -100,8 +97,8 @@ TRTLLM_NAMESPACE_END
 TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
-        "indexer_k_cache_scatter_op(Tensor k_fp8_bytes, Tensor k_scale_bytes, Tensor(a!) k_cache, "
-        "Tensor slot_mapping_fp8, Tensor slot_mapping_scale) -> ()");
+        "indexer_k_cache_scatter_op(Tensor k_fp8, Tensor k_scale, Tensor(a!) k_cache, "
+        "Tensor slot_mapping_fp8, Tensor slot_mapping_scale, int num_tokens) -> ()");
 }

 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
diff --git a/cpp/tensorrt_llm/thop/attentionOp.cpp b/cpp/tensorrt_llm/thop/attentionOp.cpp
index 9a7af4da49f6..b526310564e9 100644
--- a/cpp/tensorrt_llm/thop/attentionOp.cpp
+++ b/cpp/tensorrt_llm/thop/attentionOp.cpp
@@ -630,7 +630,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<t
     std::optional<torch::Tensor> cu_q_seqlens, std::optional<torch::Tensor> cu_kv_seqlens,
     std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
     std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
-    std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits)
+    std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata, std::optional<torch::Tensor> flash_mla_num_splits,
+    int64_t num_contexts, int64_t num_ctx_tokens)
 {
     TLLM_LOG_TRACE("Attention op starts at layer %d", layer_idx);
     // Use these tensors to infer if the attention is using KV cache
@@ -833,20 +834,9 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<t
     int32_t const num_tokens = qkv_or_q.size(0);
-    int32_t const num_ctx_tokens = host_context_lengths.slice(0, 0, num_contexts).sum().item<int32_t>();
-    int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - num_ctx_tokens;
+    int32_t const num_gen_tokens = is_gen_only ? num_tokens : num_tokens - static_cast<int32_t>(num_ctx_tokens);

     auto const ctx_total_kv_len = host_total_kv_lens.index({0}).item<int64_t>();
     auto const gen_total_kv_len = host_total_kv_lens.index({1}).item<int64_t>();
diff --git a/cpp/tensorrt_llm/thop/attentionOp.h b/cpp/tensorrt_llm/thop/attentionOp.h
index 0fc4788d6f0b..cc2b3f787f07 100644
--- a/cpp/tensorrt_llm/thop/attentionOp.h
+++ b/cpp/tensorrt_llm/thop/attentionOp.h
@@ -78,7 +78,8 @@ void attention(torch::Tensor q, std::optional<torch::Tensor> k, std::optional<t
     std::optional<torch::Tensor> fmha_scheduler_counter, std::optional<torch::Tensor> mla_bmm1_scale,
     std::optional<torch::Tensor> mla_bmm2_scale, std::optional<torch::Tensor> quant_q_buffer,
     std::optional<torch::Tensor> flash_mla_tile_scheduler_metadata = std::nullopt,
-    std::optional<torch::Tensor> flash_mla_num_splits = std::nullopt);
+    std::optional<torch::Tensor> flash_mla_num_splits = std::nullopt, int64_t num_contexts = 0,
+    int64_t num_ctx_tokens = 0);

 struct KvCachePoolPointers
 {
diff --git a/examples/visual_gen/quickstart_example.py b/examples/visual_gen/quickstart_example.py
index f738f9f52186..d608e3d4138a 100644
--- a/examples/visual_gen/quickstart_example.py
+++ b/examples/visual_gen/quickstart_example.py
@@ -20,7 +20,7 @@ def main():
         inputs="A cat sitting on a windowsill",
         params=params,
     )
-    MediaStorage.save_video(output.video, "output.avi", frame_rate=params.frame_rate)
+    MediaStorage.save_video(output.video, "output.avi", frame_rate=16.0)


 if __name__ == "__main__":
diff --git a/examples/visual_gen/visual_gen_ltx2.py b/examples/visual_gen/visual_gen_ltx2.py
index 981b7bc4d35b..c2ee0f2774f5 100755
--- a/examples/visual_gen/visual_gen_ltx2.py
+++ b/examples/visual_gen/visual_gen_ltx2.py
@@ -288,10 +288,18 @@ def main():

     start_time = time.time()

-    inputs = {
-        "prompt": args.prompt,
-        "negative_prompt": args.negative_prompt,
+    inputs = {"prompt": args.prompt}
+
+    extra_params = {
+        "guidance_rescale": args.guidance_rescale,
+        "stg_scale": args.stg_scale,
+        "modality_scale": args.modality_scale,
+        "rescale_scale": args.rescale_scale,
+        "guidance_skip_step": args.guidance_skip_step,
+        "enhance_prompt": args.enhance_prompt,
     }
+    if args.stg_blocks is not None:
+        extra_params["stg_blocks"] = args.stg_blocks

     params = VisualGenParams(
         height=args.height,
@@ -302,15 +310,10 @@ def main():
         seed=args.seed,
         num_frames=args.num_frames,
         frame_rate=args.frame_rate,
-        guidance_rescale=args.guidance_rescale,
-        input_reference=args.image,
+        negative_prompt=args.negative_prompt,
+        image=args.image,
         image_cond_strength=args.image_cond_strength,
-        stg_scale=args.stg_scale,
-        stg_blocks=args.stg_blocks,
-        modality_scale=args.modality_scale,
-        rescale_scale=args.rescale_scale,
-        guidance_skip_step=args.guidance_skip_step,
-        enhance_prompt=args.enhance_prompt,
+        extra_params=extra_params,
     )

     output = visual_gen.generate(inputs=inputs, params=params)
diff --git a/examples/visual_gen/visual_gen_wan_i2v.py b/examples/visual_gen/visual_gen_wan_i2v.py
index 5356264ebb7a..56ef5a3f922c 100644
--- a/examples/visual_gen/visual_gen_wan_i2v.py
+++ b/examples/visual_gen/visual_gen_wan_i2v.py
@@ -224,11 +224,16 @@ def main():

     start_time = time.time()

+    extra_params = {}
+    if args.last_image_path:
+        extra_params["last_image"] = args.last_image_path
+    if args.guidance_scale_2 is not None:
+        extra_params["guidance_scale_2"] = args.guidance_scale_2
+    if args.boundary_ratio is not None:
+        extra_params["boundary_ratio"] = args.boundary_ratio
+
     output = visual_gen.generate(
-        inputs={
-            "prompt": args.prompt,
-            "negative_prompt": args.negative_prompt,
-        },
+        inputs={"prompt": args.prompt},
         params=VisualGenParams(
             height=args.height,
             width=args.width,
@@ -236,10 +241,9 @@ def main():
             guidance_scale=args.guidance_scale,
             seed=args.seed,
             num_frames=args.num_frames,
-            input_reference=args.image_path,
-            last_image=args.last_image_path if args.last_image_path else None,
-            guidance_scale_2=args.guidance_scale_2,
-            boundary_ratio=args.boundary_ratio,
+            negative_prompt=args.negative_prompt,
+            image=args.image_path,
+            extra_params=extra_params if extra_params else None,
         ),
     )
diff --git a/examples/visual_gen/visual_gen_wan_t2v.py b/examples/visual_gen/visual_gen_wan_t2v.py
index 73a511aff4b7..70ca486a74c5 100755
--- a/examples/visual_gen/visual_gen_wan_t2v.py
+++ b/examples/visual_gen/visual_gen_wan_t2v.py
@@ -230,11 +230,14 @@ def main():

     start_time = time.time()

+    extra_params = {}
+    if args.guidance_scale_2 is not None:
+        extra_params["guidance_scale_2"] = args.guidance_scale_2
+    if args.boundary_ratio is not None:
+        extra_params["boundary_ratio"] = args.boundary_ratio
+
     output = visual_gen.generate(
-        inputs={
-            "prompt": args.prompt,
-            "negative_prompt": args.negative_prompt,
-        },
+        inputs={"prompt": args.prompt},
         params=VisualGenParams(
             height=args.height,
             width=args.width,
@@ -242,8 +245,8 @@ def main():
             guidance_scale=args.guidance_scale,
             seed=args.seed,
             num_frames=args.num_frames,
-            guidance_scale_2=args.guidance_scale_2,
-            boundary_ratio=args.boundary_ratio,
+            negative_prompt=args.negative_prompt,
+            extra_params=extra_params if extra_params else None,
         ),
     )
diff --git a/tensorrt_llm/__init__.py b/tensorrt_llm/__init__.py
index 1e7618500bb0..375c8938cdd1 100644
--- a/tensorrt_llm/__init__.py
+++ b/tensorrt_llm/__init__.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -132,8 +132,9 @@ def _setup_vendored_triton_kernels():
 from .python_plugin import PluginBase
 from .sampling_params import SamplingParams
 from .version import __version__
-from .visual_gen import (VisualGen, VisualGenArgs, VisualGenError,
-                         VisualGenParams, VisualGenResult)
+from .visual_gen import (ExtraParamSchema, VisualGen, VisualGenArgs,
+                         VisualGenError, VisualGenParams, VisualGenParamsError,
+                         VisualGenResult)

 __all__ = [
     'AutoConfig',
@@ -182,7 +183,9 @@ def _setup_vendored_triton_kernels():
     'TrtLlmArgs',
     'SamplingParams',
     'VisualGenArgs',
+    'ExtraParamSchema',
     'VisualGenError',
+    'VisualGenParamsError',
     'VisualGenResult',
     'DisaggregatedParams',
     'KvCacheConfig',
diff --git a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
index 2e3428838725..3c6b8dd4cf4f 100644
--- a/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
+++ b/tensorrt_llm/_torch/attention_backend/sparse/dsa.py
@@ -1458,34 +1458,18 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor,
         if metadata.kv_cache_manager is None or metadata.slot_mapping_fp8 is None:
             return

-        # [num_blocks, block_size, 1, per_token_size ]
         k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
             self.layer_idx)
         num_tokens = k_fp8.shape[0]
-        head_dim = k_fp8.shape[1]
-        scale_size = k_scale.shape[1] * 4  # Convert to bytes (float32 = 4 bytes)
-
-        # Convert to bytes: flatten first, then view as uint8, then reshape
-        k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(
-            num_tokens, head_dim)
-
-        # k_scale: for single-element tensors, contiguous() may be no-op
-        # Fix stride(-1) for byte-level view
-        k_scale_flat = k_scale.view(-1)
-        if k_scale_flat.stride(-1) != 1:
-            k_scale_flat = torch.as_strided(k_scale_flat.contiguous(),
-                                            size=(k_scale_flat.numel(), ),
-                                            stride=(1, ))
-        k_scale_bytes = k_scale_flat.view(torch.uint8).view(
-            num_tokens, scale_size)
-
-        # Use CUDA kernel to scatter FP8 and scale bytes into cache
-        flat_indices_fp8 = metadata.slot_mapping_fp8[:num_tokens]
-        flat_indices_scale = metadata.slot_mapping_scale[:num_tokens]
-        torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes,
-                                                    k_cache, flat_indices_fp8,
-                                                    flat_indices_scale)
+
+        # The C++ op reinterprets k_fp8 (FP8) and k_scale (float32) as raw
+        # bytes internally and only reads the first num_tokens entries from
+        # the slot mapping buffers, avoiding Python-side view/slice overhead.
+        torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache,
+                                                    metadata.slot_mapping_fp8,
+                                                    metadata.slot_mapping_scale,
+                                                    num_tokens)

     def sparse_attn_indexer(
         self,
diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py
index 541c53b59bee..9f32039f145f 100644
--- a/tensorrt_llm/_torch/attention_backend/trtllm.py
+++ b/tensorrt_llm/_torch/attention_backend/trtllm.py
@@ -406,6 +406,8 @@ def run(
         mla_bmm1_scale: Optional[torch.Tensor] = None,
         mla_bmm2_scale: Optional[torch.Tensor] = None,
         quant_q_buffer: Optional[torch.Tensor] = None,
+        num_contexts: int = 0,
+        num_ctx_tokens: int = 0,
     ):
         """
         Run the attention operation.
@@ -652,6 +654,8 @@ def run(
                 quant_q_buffer,
                 self.quant_config,
                 self.kv_cache_manager,
+                num_contexts,
+                num_ctx_tokens,
                 global_layer_idx=self.global_layer_idx,
             )
         else:
@@ -736,6 +740,8 @@ def run(
                 quant_q_buffer,
                 self.flash_mla_tile_scheduler_metadata,
                 self.flash_mla_num_splits,
+                num_contexts=num_contexts,
+                num_ctx_tokens=num_ctx_tokens,
             )

         if self.print_skip_softmax_stat:
@@ -2087,7 +2093,9 @@ def forward(
             fmha_scheduler_counter=fmha_scheduler_counter,
             mla_bmm1_scale=mla_bmm1_scale,
             mla_bmm2_scale=mla_bmm2_scale,
-            quant_q_buffer=quant_q_buffer)
+            quant_q_buffer=quant_q_buffer,
+            num_contexts=metadata.num_contexts,
+            num_ctx_tokens=metadata.num_ctx_tokens)

         if output_sf is None:
             return output
diff --git a/tensorrt_llm/_torch/attention_backend/trtllm_gen.py b/tensorrt_llm/_torch/attention_backend/trtllm_gen.py
index 439831d4bf7d..ad53a749da27 100644
--- a/tensorrt_llm/_torch/attention_backend/trtllm_gen.py
+++ b/tensorrt_llm/_torch/attention_backend/trtllm_gen.py
@@ -1437,23 +1437,6 @@ def run_mla_generation(self, params: EnqueueGenerationParams) -> None:
         params.context_buf.copy_(mla_out.reshape_as(params.context_buf))


-def _parse_request_types(host_request_types: torch.Tensor) -> Tuple[int, int]:
-    """
-    Parse request types to count context and generation requests.
-
-    Args:
-        host_request_types: Request types tensor (0=context, 1=generation).
-        num_seqs: Total number of sequences.
-
-    Returns:
-        Tuple of (num_contexts, num_generations).
-    """
-
-    num_generations = host_request_types.sum().item()
-    num_contexts = host_request_types.size(0) - num_generations
-    return num_contexts, num_generations
-
-
 def is_supported(
     q: torch.Tensor,
     num_heads: int,
@@ -1636,6 +1619,8 @@ def trtllm_gen_attention(
     quant_q_buffer: Optional[torch.Tensor],
     quant_config: Optional[QuantConfig],
     kv_cache_manager: Optional[KVCacheManager],
+    num_contexts: int,
+    num_ctx_tokens: int,
     global_layer_idx: Optional[int] = None,
 ) -> None:
     """
@@ -1766,20 +1751,10 @@ def trtllm_gen_attention(
     if attention_input_type is not None:
         attn_input_type = AttentionInputType(attention_input_type)

-    num_contexts, num_generations = _parse_request_types(host_request_types)
     is_gen_only = attn_input_type == AttentionInputType.generation_only
-    is_ctx_only = attn_input_type == AttentionInputType.context_only
-
-    if is_gen_only:
-        num_ctx_tokens = 0
-        num_gen_tokens = num_tokens
-    elif is_ctx_only:
-        num_ctx_tokens = num_tokens
-        num_gen_tokens = 0
-    else:
-        num_ctx_tokens = int(host_context_lengths[:num_contexts].sum()) if num_contexts > 0 else 0
-        num_gen_tokens = num_tokens - num_ctx_tokens
+
+    num_generations = host_request_types.size(0) - num_contexts
+    num_gen_tokens = num_tokens - num_ctx_tokens

     # Prepare Workspace
     # Use upper-bound token counts for workspace sizing to avoid repeated
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
index f1c99267ed02..d7b7a1a4352c 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
@@ -82,6 +82,9 @@ def __init__(self):
         # keeping a separate copy here since we sometimes have to overwrite the original values
         self.host_past_kv_lengths: Optional[torch.Tensor] = None  # [max_batch] int32 pinned
         self.host_context_lengths: Optional[torch.Tensor] = None  # [max_batch] int32 pinned
+        # Batch counts for thop.attention (updated every forward in plan_host)
+        self.num_contexts: int = 0
+        self.num_ctx_tokens: int = 0
         # Persistent block_offsets buffer for CUDA graph compatibility.
         # Pre-allocated to max size so the tensor address is stable across replays.
         self.block_offsets: Optional[torch.Tensor] = None
@@ -171,6 +174,10 @@ def plan_host(
         """
         num_seq = num_prefill + num_decode

+        # Batch counts for thop.attention
+        self.num_contexts = num_prefill
+        self.num_ctx_tokens = int(seq_len_host[:num_prefill].sum()) if num_prefill > 0 else 0
+
         # host_request_types: 0 = prefill (context), 1 = decode (generation)
         self.host_request_types[:num_prefill].fill_(0)
         self.host_request_types[num_prefill:num_seq].fill_(1)
@@ -500,6 +507,10 @@ def trtllm_mha_with_cache(
         None,  # mla_bmm1_scale
         None,  # mla_bmm2_scale
         None,  # quant_q_buffer
+        None,  # flash_mla_tile_scheduler_metadata
+        None,  # flash_mla_num_splits
+        num_contexts=_GlobalTrtllmPlanner.num_contexts,
+        num_ctx_tokens=_GlobalTrtllmPlanner.num_ctx_tokens,
     )

     if out is not None:
diff --git a/tensorrt_llm/_torch/visual_gen/executor.py b/tensorrt_llm/_torch/visual_gen/executor.py
index 3062de13fd6b..b773eec70f61 100644
--- a/tensorrt_llm/_torch/visual_gen/executor.py
+++ b/tensorrt_llm/_torch/visual_gen/executor.py
@@ -18,45 +18,39 @@
 @dataclass
 class DiffusionRequest:
-    """Request for diffusion inference with explicit model-specific parameters."""
+    """Request for diffusion inference.
+
+    Universal parameters are top-level fields with ``None`` meaning
+    "use model default" (resolved by the executor before calling
+    ``pipeline.infer()``). Model-specific parameters live in
+    ``extra_params`` and are passed through to the pipeline.
+    """

     request_id: int
     prompt: List[str]
     negative_prompt: Optional[str] = None
-    height: int = 720
-    width: int = 1280
-    num_inference_steps: int = 50
-    guidance_scale: float = 5.0
-    max_sequence_length: int = 512
+
+    # Core — None means "use model default" (resolved by executor)
+    height: Optional[int] = None
+    width: Optional[int] = None
+    num_inference_steps: Optional[int] = None
+    guidance_scale: Optional[float] = None
+    max_sequence_length: Optional[int] = None
     seed: int = 42

-    # Video-specific parameters
-    num_frames: int = 81
-    frame_rate: float = 24.0
+    # Video
+    num_frames: Optional[int] = None
+    frame_rate: Optional[float] = None

-    # Image-specific parameters
+    # Image
     num_images_per_prompt: int = 1

-    # Advanced parameters
-    guidance_rescale: float = 0.0
-    output_type: str = "pt"
+    # Conditioning inputs
+    image: Optional[Union[str, bytes, List[Union[str, bytes]]]] = None
+    image_cond_strength: Optional[float] = None

-    # LTX-2 multi-modal guidance (STG / modality guidance)
-    stg_scale: float = 0.0
-    stg_blocks: Optional[List[int]] = None
-    modality_scale: float = 1.0
-    rescale_scale: float = 0.0
-    guidance_skip_step: int = 0
-    enhance_prompt: bool = False
-
-    # Image-to-video parameters
-    image: Optional[Union[str, List[str]]] = None
-    image_cond_strength: float = 1.0
-
-    # Wan-specific parameters
-    guidance_scale_2: Optional[float] = None
-    boundary_ratio: Optional[float] = None
-    last_image: Optional[Union[str, List[str]]] = None
+    # Model-specific overflow (from VisualGenParams.extra_params)
+    extra_params: Optional[dict] = None


 @dataclass
@@ -74,6 +68,33 @@ class DiffusionResponse:
     error_msg: Optional[str] = None


+# Python type name → accepted Python types for ExtraParamSchema validation.
+_TYPE_MAP = {
+    "float": (float, int),
+    "int": (int,),
+    "bool": (bool,),
+    "str": (str,),
+    "list": (list,),
+}
+
+# Generation config fields that pipelines declare defaults for.
+# If a user sets one of these but the pipeline doesn't declare it in
+# DEFAULT_GENERATION_PARAMS, the value would be silently ignored, so
+# validation rejects it instead.
+# Conditioning inputs (image, negative_prompt, mask, image_cond_strength)
+# are excluded — they are validated at runtime by the pipeline's infer().
+_GENERATION_CONFIG_FIELDS = frozenset(
+    {
+        "height",
+        "width",
+        "num_inference_steps",
+        "guidance_scale",
+        "max_sequence_length",
+        "num_frames",
+        "frame_rate",
+    }
+)
+
+
 class DiffusionExecutor:
     """Execution engine for diffusion models running in worker processes."""

@@ -157,10 +178,19 @@ def _load_pipeline(self):
         # Sync all workers
         dist.barrier()

-        # Send READY signal
+        # Send READY signal with pipeline metadata for the client.
         if self.rank == 0:
             logger.info(f"Worker {self.device_id}: Sending READY")
-            self.response_queue.put(DiffusionResponse(request_id=-1, output="READY"))
+            self.response_queue.put(
+                DiffusionResponse(
+                    request_id=-1,
+                    output={
+                        "status": "READY",
+                        "default_generation_params": self.pipeline.DEFAULT_GENERATION_PARAMS,
+                        "extra_param_specs": self.pipeline.EXTRA_PARAM_SPECS,
+                    },
+                )
+            )

     def serve_forever(self):
         """Main execution loop."""
@@ -185,17 +215,115 @@ def serve_forever(self):
                 logger.info(f"Worker {self.device_id}: Processing request {req.request_id}")
                 self.process_request(req)

+    def _merge_defaults(self, req: DiffusionRequest):
+        """Fill ``None`` fields in *req* with pipeline-specific defaults.
+
+        Merges both universal defaults (from ``DEFAULT_GENERATION_PARAMS``)
+        and extra_param defaults (from ``EXTRA_PARAM_SPECS``).
+        """
+        # Universal field defaults
+        for field_name, default_value in self.pipeline.DEFAULT_GENERATION_PARAMS.items():
+            if hasattr(req, field_name) and getattr(req, field_name) is None:
+                setattr(req, field_name, default_value)
+
+        # Extra param defaults — fill all declared keys so infer() can use direct access
+        specs = self.pipeline.EXTRA_PARAM_SPECS
+        if specs:
+            if req.extra_params is None:
+                req.extra_params = {}
+            for key, spec in specs.items():
+                if key not in req.extra_params:
+                    req.extra_params[key] = spec.default
+
+        self._validate_request(req)
+
+    def _validate_request(self, req: DiffusionRequest):
+        """Validate *req* against the loaded pipeline's declared parameters.
+
+        Raises ``VisualGenParamsError`` on:
+        - Unknown ``extra_params`` keys
+        - Universal fields (e.g. ``num_frames``) set by the user but not
+          declared in the pipeline's ``DEFAULT_GENERATION_PARAMS``
+        - Type mismatches for ``extra_params`` values
+        - Out-of-range ``extra_params`` values
+        """
+        # Lazy import to avoid circular dependency
+        # (executor → visual_gen.visual_gen → _torch.visual_gen → executor)
+        from tensorrt_llm.visual_gen.visual_gen import VisualGenParamsError
+
+        errors: list[str] = []
+        pipeline_name = self.pipeline.__class__.__name__
+        declared_defaults = self.pipeline.DEFAULT_GENERATION_PARAMS
+        specs = self.pipeline.EXTRA_PARAM_SPECS
+
+        # --- unknown extra_params keys ---
+        if req.extra_params:
+            unknown = set(req.extra_params.keys()) - set(specs.keys())
+            if unknown:
+                errors.append(
+                    f"Unknown extra_params {sorted(unknown)} for {pipeline_name}. "
+                    f"Supported: {sorted(specs.keys())}"
+                )
+
+        # --- unsupported universal fields ---
+        # Check generation config fields the user explicitly set (not None)
+        # that the pipeline never declared in DEFAULT_GENERATION_PARAMS.
+        # Conditioning inputs (image, negative_prompt, mask) are excluded —
+        # they are validated at runtime by the pipeline's infer().
+        for field_name in _GENERATION_CONFIG_FIELDS:
+            value = getattr(req, field_name, None)
+            if value is not None and field_name not in declared_defaults:
+                errors.append(
+                    f"Parameter '{field_name}' is set but {pipeline_name} does "
+                    f"not use it (not in DEFAULT_GENERATION_PARAMS). "
+                    f"It would be silently ignored."
+                )

+        # --- extra_params type and range checks ---
+        if req.extra_params:
+            for key, value in req.extra_params.items():
+                if key not in specs:
+                    continue  # already reported as unknown above
+                spec = specs[key]
+                # Skip None values (param left at its None default)
+                if value is None:
+                    continue
+                # Type check
+                expected_types = _TYPE_MAP.get(spec.type)
+                if expected_types and not isinstance(value, expected_types):
+                    errors.append(
+                        f"extra_params['{key}'] expected type '{spec.type}', "
+                        f"got {type(value).__name__}: {value!r}"
+                    )
+                    continue  # skip range check if type is wrong
+                # Range check (numeric only)
+                if spec.range is not None and isinstance(value, (int, float)):
+                    lo, hi = spec.range
+                    if not (lo <= value <= hi):
+                        errors.append(
+                            f"extra_params['{key}'] value {value} is out of range [{lo}, {hi}]"
+                        )

+        if errors:
+            msg = f"Parameter validation failed for {pipeline_name}:\n" + "\n".join(
+                f"  - {e}" for e in errors
+            )
+            raise VisualGenParamsError(msg)

     def process_request(self, req: DiffusionRequest):
         """Process a single request."""
-        cache_key = self.pipeline.warmup_cache_key(req.height, req.width, num_frames=req.num_frames)
-        if self.pipeline._warmed_up_shapes and cache_key not in self.pipeline._warmed_up_shapes:
-            logger.warning(
-                f"Requested shape {cache_key} was not warmed up. "
-                f"First request with this shape will be slower due to "
-                f"torch.compile recompilation or CUDA graph capture. "
-                f"Warmed-up shapes: {self.pipeline._warmed_up_shapes}"
-            )
         try:
+            self._merge_defaults(req)
+            cache_key = self.pipeline.warmup_cache_key(
+                req.height, req.width, num_frames=req.num_frames
+            )
+            if self.pipeline._warmed_up_shapes and cache_key not in self.pipeline._warmed_up_shapes:
+                logger.warning(
+                    f"Requested shape {cache_key} was not warmed up. "
+                    f"First request with this shape will be slower due to "
+                    f"torch.compile recompilation or CUDA graph capture. "
+                    f"Warmed-up shapes: {self.pipeline._warmed_up_shapes}"
+                )
             output = self.pipeline.infer(req)
             if self.rank == 0:
                 self.response_queue.put(DiffusionResponse(request_id=req.request_id, output=output))
diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
index 457c90fb3e25..d1c3f6a8f4d1 100644
--- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
+++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux.py
@@ -229,6 +229,14 @@ def post_load_weights(self) -> None:
         # Enable TeaCache with FLUX.1-specific polynomial coefficients
         self._setup_teacache(self.transformer, FLUX_TEACACHE_COEFFICIENTS)

+    DEFAULT_GENERATION_PARAMS = {
+        "height": 1024,
+        "width": 1024,
+        "num_inference_steps": 50,
+        "guidance_scale": 3.5,
+        "max_sequence_length": 512,
+    }
+
     def infer(self, req):
         """Run inference from DiffusionRequest."""
         return self.forward(
diff --git a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
index 522421124ad4..a3a2b41b5cee 100644
--- a/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
+++ b/tensorrt_llm/_torch/visual_gen/models/flux/pipeline_flux2.py
@@ -322,6 +322,14 @@ def post_load_weights(self) -> None:
         # Enable TeaCache with FLUX.2-specific polynomial coefficients
         self._setup_teacache(self.transformer, FLUX2_TEACACHE_COEFFICIENTS)

+    DEFAULT_GENERATION_PARAMS = {
+        "height": 1024,
+        "width": 1024,
+        "num_inference_steps": 50,
+        "guidance_scale": 3.5,
+        "max_sequence_length": 512,
+    }
+
     def infer(self, req):
         """Run inference from DiffusionRequest."""
         return self.forward(
diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
index 6c5e63683afd..571a7d59d21d 100644
--- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
+++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2.py
@@ -18,7 +18,7 @@
 from tensorrt_llm._torch.visual_gen.config import PipelineComponent
 from tensorrt_llm._torch.visual_gen.cuda_graph_runner import CUDAGraphRunner, CUDAGraphRunnerConfig
 from tensorrt_llm._torch.visual_gen.output import MediaOutput
-from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
+from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline, ExtraParamSchema
 from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
 from tensorrt_llm._torch.visual_gen.teacache import CacheContext
 from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
@@ -1040,8 +1040,63 @@ def _build_denoise_mask(
     # Inference
     # ------------------------------------------------------------------

+    DEFAULT_GENERATION_PARAMS = {
+        "height": 512,
+        "width": 768,
+        "num_inference_steps": 40,
+        "guidance_scale": 4.0,
+        "max_sequence_length": 1024,
+        "num_frames": 121,
+        "frame_rate": 24.0,
+        "image_cond_strength": 1.0,
+    }
+
+    EXTRA_PARAM_SPECS = {
+        "output_type": ExtraParamSchema(
+            type="str",
+            default="pt",
+            description="Output type: 'pt' for PyTorch tensors, 'pil' for PIL images.",
+        ),
+        "guidance_rescale": ExtraParamSchema(
+            type="float",
+            default=0.0,
+            description="Guidance rescale factor to prevent overexposure.",
+        ),
+        "stg_scale": ExtraParamSchema(
+            type="float",
+            default=0.0,
+            description="Spatiotemporal guidance scale for multi-modal guidance.",
+        ),
+        "stg_blocks": ExtraParamSchema(
+            type="list",
+            description="Transformer block indices for STG perturbation.",
+        ),
+        "modality_scale": ExtraParamSchema(
+            type="float",
+            default=1.0,
+            description="Modality guidance scale for multi-modal generation.",
+        ),
+        "rescale_scale": ExtraParamSchema(
+            type="float",
+            default=0.0,
+            range=(0.0, 1.0),
+            description="CFG rescale factor for multi-modal guidance.",
+        ),
+        "guidance_skip_step": ExtraParamSchema(
+            type="int",
+            default=0,
+            description="Number of initial denoising steps to skip guidance.",
+        ),
+        "enhance_prompt": ExtraParamSchema(
+            type="bool",
+            default=False,
+            description="Use Gemma3 LLM to enhance the prompt before generation.",
+        ),
+    }
+
     def infer(self, req):
         """Run inference with request parameters."""
+        extra = req.extra_params or {}
         return self.forward(
             prompt=req.prompt,
             negative_prompt=req.negative_prompt,
@@ -1052,17 +1107,17 @@ def infer(self, req):
             num_inference_steps=req.num_inference_steps,
             guidance_scale=req.guidance_scale,
             seed=req.seed,
-            output_type=req.output_type,
-            guidance_rescale=req.guidance_rescale,
+            output_type=extra["output_type"],
+            guidance_rescale=extra["guidance_rescale"],
             max_sequence_length=req.max_sequence_length,
-            image=getattr(req, "image", None),
-            image_cond_strength=getattr(req, "image_cond_strength", 1.0),
-            stg_scale=getattr(req, "stg_scale", 0.0),
-            stg_blocks=getattr(req, "stg_blocks", None),
-            modality_scale=getattr(req, "modality_scale", 1.0),
-            rescale_scale=getattr(req, "rescale_scale", 0.0),
-            guidance_skip_step=getattr(req, "guidance_skip_step", 0),
-            enhance_prompt=getattr(req, "enhance_prompt", False),
+            image=req.image,
+            image_cond_strength=req.image_cond_strength,
+            stg_scale=extra["stg_scale"],
+            stg_blocks=extra["stg_blocks"],
+            modality_scale=extra["modality_scale"],
+            rescale_scale=extra["rescale_scale"],
+            guidance_skip_step=extra["guidance_skip_step"],
+            enhance_prompt=extra["enhance_prompt"],
         )

     # ------------------------------------------------------------------
diff --git a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py
index 8a83f25b61ad..77aee2a58ba7 100644
--- a/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py
+++ b/tensorrt_llm/_torch/visual_gen/models/ltx2/pipeline_ltx2_two_stages.py
@@ -564,6 +564,7 @@ def load_standard_components(
     # ------------------------------------------------------------------

     def infer(self, req):
+        extra = req.extra_params or {}
         return self.forward(
             prompt=req.prompt,
             negative_prompt=req.negative_prompt,
@@ -574,17 +575,17 @@ def infer(self, req):
             num_inference_steps=req.num_inference_steps,
             guidance_scale=req.guidance_scale,
             seed=req.seed,
-            output_type=req.output_type,
-            guidance_rescale=req.guidance_rescale,
+            output_type=extra["output_type"],
+            guidance_rescale=extra["guidance_rescale"],
             max_sequence_length=req.max_sequence_length,
-            image=getattr(req, "image", None),
-            image_cond_strength=getattr(req, "image_cond_strength", 1.0),
-            stg_scale=getattr(req, "stg_scale", 0.0),
-            stg_blocks=getattr(req, "stg_blocks", None),
-            modality_scale=getattr(req, "modality_scale", 1.0),
-            rescale_scale=getattr(req, "rescale_scale", 0.0),
-            guidance_skip_step=getattr(req, "guidance_skip_step", 0),
-            enhance_prompt=getattr(req, "enhance_prompt", False),
+            image=req.image,
+            image_cond_strength=req.image_cond_strength,
+            stg_scale=extra["stg_scale"],
+            stg_blocks=extra["stg_blocks"],
+            modality_scale=extra["modality_scale"],
+            rescale_scale=extra["rescale_scale"],
+            guidance_skip_step=extra["guidance_skip_step"],
+            enhance_prompt=extra["enhance_prompt"],
         )

     # ------------------------------------------------------------------
diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
index 2788bedd489e..b4d42156fd72 100644
--- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
+++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan.py
@@ -10,7 +10,7 @@
 from tensorrt_llm._torch.visual_gen.config import PipelineComponent
 from tensorrt_llm._torch.visual_gen.output import MediaOutput
-from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
+from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline, ExtraParamSchema
 from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
 from tensorrt_llm._torch.visual_gen.teacache import ExtractorConfig, register_extractor_from_config
 from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
@@ -301,8 +301,33 @@ def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> N
             max_sequence_length=512,
         )

+    DEFAULT_GENERATION_PARAMS = {
+        "height": 480,
+        "width": 832,
+        "num_inference_steps": 50,
+        "guidance_scale": 5.0,
+        "max_sequence_length": 512,
+        "num_frames": 81,
+        "frame_rate": 24.0,
+    }
+
+    EXTRA_PARAM_SPECS = {
+        "guidance_scale_2": ExtraParamSchema(
+            type="float",
+            default=None,
+            description="Second guidance scale for Wan 2.2 two-stage denoising.",
+        ),
+        "boundary_ratio": ExtraParamSchema(
+            type="float",
+            default=None,
+            range=(0.0, 1.0),
+            description="Timestep boundary ratio for switching guidance scales (Wan 2.2).",
+        ),
+    }
+
     def infer(self, req):
         """Run inference with request parameters."""
+        extra = req.extra_params or {}
         return self.forward(
             prompt=req.prompt,
             negative_prompt=req.negative_prompt,
@@ -311,8 +336,8 @@ def infer(self, req):
             num_frames=req.num_frames,
             num_inference_steps=req.num_inference_steps,
             guidance_scale=req.guidance_scale,
-            guidance_scale_2=req.guidance_scale_2,
-            boundary_ratio=req.boundary_ratio,
+            guidance_scale_2=extra["guidance_scale_2"],
+            boundary_ratio=extra["boundary_ratio"],
             seed=req.seed,
             max_sequence_length=req.max_sequence_length,
         )
diff --git a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
index 95e937b55fca..8949fd502a76 100644
--- a/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
+++ b/tensorrt_llm/_torch/visual_gen/models/wan/pipeline_wan_i2v.py
@@ -13,7 +13,7 @@
 from tensorrt_llm._torch.visual_gen.config import PipelineComponent
 from tensorrt_llm._torch.visual_gen.output import MediaOutput
-from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline
+from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline, ExtraParamSchema
 from tensorrt_llm._torch.visual_gen.pipeline_registry import register_pipeline
 from tensorrt_llm._torch.visual_gen.teacache import ExtractorConfig, register_extractor_from_config
 from tensorrt_llm._torch.visual_gen.utils import postprocess_video_tensor
@@ -363,6 +363,36 @@ def _run_warmup(self, height: int, width: int, num_frames: int, steps: int) -> N
             max_sequence_length=512,
         )

+    DEFAULT_GENERATION_PARAMS = {
+        "height": 480,
+        "width": 832,
+        "num_inference_steps": 50,
+        "guidance_scale": 5.0,
+        "max_sequence_length": 512,
+        "num_frames": 81,
+        "frame_rate": 24.0,
+        "image_cond_strength": 1.0,
+    }
+
+    EXTRA_PARAM_SPECS = {
+        "guidance_scale_2": ExtraParamSchema(
+            type="float",
+            default=None,
+            description="Second guidance scale for Wan 2.2 two-stage denoising.",
+        ),
+        "boundary_ratio": ExtraParamSchema(
+            type="float",
+            default=None,
+            range=(0.0, 1.0),
+            description="Timestep boundary ratio for switching guidance scales (Wan 2.2).",
+        ),
+        "last_image": ExtraParamSchema(
+            type="str",
+            default=None,
+            description="Last frame path for video interpolation (Wan I2V).",
+        ),
+    }
+
     def infer(self, req):
         """Run inference with request parameters."""
         # Extract image from request (can be path, PIL Image, or torch.Tensor)
@@ -370,7 +400,8 @@ def infer(self, req):
             raise ValueError("I2V pipeline requires 'image' parameter")
         image = req.image[0] if isinstance(req.image, list) else req.image
-        last_image = req.last_image
+        extra = req.extra_params or {}
+        last_image = extra["last_image"]
         if last_image is not None and isinstance(last_image, list):
             last_image = last_image[0] if last_image else None
@@ -384,8 +415,8 @@ def infer(self, req):
             num_frames=req.num_frames,
             num_inference_steps=req.num_inference_steps,
             guidance_scale=req.guidance_scale,
-            guidance_scale_2=req.guidance_scale_2,
-            boundary_ratio=req.boundary_ratio,
+            guidance_scale_2=extra["guidance_scale_2"],
+            boundary_ratio=extra["boundary_ratio"],
             seed=req.seed,
             max_sequence_length=req.max_sequence_length,
             last_image=last_image,
diff --git a/tensorrt_llm/_torch/visual_gen/pipeline.py b/tensorrt_llm/_torch/visual_gen/pipeline.py
index b564378a6fda..4b3a6e31bc2e 100644
--- a/tensorrt_llm/_torch/visual_gen/pipeline.py
+++ b/tensorrt_llm/_torch/visual_gen/pipeline.py
@@ -6,8 +6,10 @@
 import torch
 import torch.distributed as dist
 import torch.nn as nn
+from pydantic import Field

 from tensorrt_llm._utils import nvtx_range
+from tensorrt_llm.llmapi.utils import StrictBaseModel
 from tensorrt_llm.logger import logger
 from tensorrt_llm.mapping import Mapping

@@ -16,6 +18,22 @@
 from .modules.vae.parallel_vae_interface import ParallelVAEFactory
 from .teacache import TeaCacheBackend

+
+class ExtraParamSchema(StrictBaseModel):
+    """Schema for a model-specific extra parameter.
+
+    Returned by ``VisualGen.extra_param_specs`` so callers can
+    discover which ``extra_params`` keys are valid for the loaded pipeline.
+    """
+
+    type: str = Field(description="Python type name (e.g. 'float', 'int', 'bool').")
+    default: Any = Field(default=None, description="Default value used when omitted.")
+    description: str = Field(default="", description="Human-readable description.")
+    range: Optional[tuple] = Field(
+        default=None, description="Optional (min, max) range for numeric params."
+    )
+
+
 if TYPE_CHECKING:
     from .config import DiffusionModelConfig

@@ -219,6 +237,16 @@ def vae_adapter_class(self) -> Type[ParallelVAEFactory] | None:
         """Return the VAE adapter class for the pipeline."""
         return None

+    #: Model-specific extra parameter specs. Subclasses override to declare
+    #: which ``extra_params`` keys they accept and their metadata.
+    #: Maps parameter names to ``ExtraParamSchema`` instances.
+    EXTRA_PARAM_SPECS: Dict[str, ExtraParamSchema] = {}
+
+    #: Model-specific defaults for ``None`` fields in ``VisualGenParams``.
+    #: Keys should match ``DiffusionRequest`` field names. The executor
+    #: merges these into the request before calling ``infer()``.
+    DEFAULT_GENERATION_PARAMS: dict = {}
+
     def infer(self, req: Any):
         raise NotImplementedError
diff --git a/tensorrt_llm/bench/benchmark/visual_gen.py b/tensorrt_llm/bench/benchmark/visual_gen.py
index 1a579df8dd9b..d3cadf4c940f 100644
--- a/tensorrt_llm/bench/benchmark/visual_gen.py
+++ b/tensorrt_llm/bench/benchmark/visual_gen.py
@@ -240,6 +240,8 @@ def visual_gen_command(
         gen_params_kwargs["num_inference_steps"] = num_inference_steps
     if guidance_scale is not None:
         gen_params_kwargs["guidance_scale"] = guidance_scale
+    if negative_prompt is not None:
+        gen_params_kwargs["negative_prompt"] = negative_prompt

     gen_params = VisualGenParams(**gen_params_kwargs)

@@ -286,7 +288,6 @@ def visual_gen_command(
         visual_gen=visual_gen,
         input_requests=input_requests,
         gen_params=gen_params,
-        negative_prompt=negative_prompt,
         max_concurrency=max_concurrency,
     )
     benchmark_duration = time.perf_counter() - benchmark_start
@@ -329,7 +330,6 @@ def _run_benchmark(
     visual_gen,
     input_requests,
     gen_params,
-    negative_prompt: Optional[str],
     max_concurrency: int,
 ) -> list[VisualGenRequestOutput]:
     """Run the benchmark loop, dispatching requests with concurrency control."""
@@ -338,14 +338,13 @@ def _run_benchmark(
     outputs: list[VisualGenRequestOutput] = []

     if max_concurrency <= 1:
-        outputs = _run_sequential(visual_gen, input_requests, gen_params, negative_prompt)
+        outputs = _run_sequential(visual_gen, input_requests, gen_params)
     else:
         outputs = asyncio.run(
             _run_concurrent(
                 visual_gen,
                 input_requests,
                 gen_params,
-                negative_prompt,
                 max_concurrency,
             )
         )

     return outputs


-def _run_sequential(
-    visual_gen, input_requests, gen_params, negative_prompt
-) -> list[VisualGenRequestOutput]:
+def _run_sequential(visual_gen, input_requests, gen_params) -> list[VisualGenRequestOutput]:
     """Run requests one at a time, measuring per-request latency."""
     outputs = []
     for req in input_requests:
         output = VisualGenRequestOutput()
-        inputs = (
-            {"prompt": req.prompt, "negative_prompt": negative_prompt}
-            if negative_prompt
-            else req.prompt
-        )
         st = time.perf_counter()
         try:
-            visual_gen.generate(inputs=inputs, params=gen_params)
+            visual_gen.generate(inputs=req.prompt, params=gen_params)
             output.e2e_latency = time.perf_counter() - st
             output.success = True
         except Exception as e:
@@ -384,7 +376,7 @@ def _run_sequential(


 async def _run_concurrent(
-    visual_gen, input_requests, gen_params, negative_prompt, max_concurrency
+    visual_gen, input_requests, gen_params, max_concurrency
 ) -> list[VisualGenRequestOutput]:
     """Run requests concurrently using generate_async with a semaphore."""
     import asyncio
@@ -393,16 +385,11 @@ async def _run_concurrent(
     outputs: list[VisualGenRequestOutput] = [VisualGenRequestOutput() for _ in input_requests]

     async def _generate_one(idx, req):
-        inputs = (
-            {"prompt": req.prompt, "negative_prompt": negative_prompt}
-            if negative_prompt
-            else req.prompt
-        )
         async with semaphore:
             output = outputs[idx]
             st = time.perf_counter()
             try:
-                future = visual_gen.generate_async(inputs=inputs, params=gen_params)
+                future = visual_gen.generate_async(inputs=req.prompt, params=gen_params)
                 await future.result()
                 output.e2e_latency = time.perf_counter() - st
                 output.success = True
diff --git a/tensorrt_llm/serve/openai_disagg_server.py b/tensorrt_llm/serve/openai_disagg_server.py
index 06095da4c0d3..ee1edbe5d00c 100644
--- a/tensorrt_llm/serve/openai_disagg_server.py
+++ b/tensorrt_llm/serve/openai_disagg_server.py
@@ -41,7 +41,8 @@
 from tensorrt_llm.serve.openai_client import OpenAIClient, OpenAIHttpClient
 from tensorrt_llm.serve.openai_disagg_service import (
     OpenAIDisaggregatedService, ResponseHooks)
-from tensorrt_llm.serve.openai_protocol import (UCompletionRequest,
+from tensorrt_llm.serve.openai_protocol import (DisaggregatedParams,
+                                                UCompletionRequest,
                                                 UCompletionResponse)
 from tensorrt_llm.serve.perf_metrics import DisaggPerfMetricsCollector
 from tensorrt_llm.serve.responses_utils import (ServerArrivalTimeMiddleware,
@@ -92,8 +93,8 @@ def __init__(self,
         self._metrics_interval_secs = metrics_interval_secs

         self._ctx_servers, self._gen_servers = get_ctx_gen_server_addrs(config.server_configs)
-        self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock)
-        self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock)
+        self._ctx_router = create_router(config.ctx_router_config, self._ctx_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock, disagg_node_id=config.node_id)
+        self._gen_router = create_router(config.gen_router_config, self._gen_servers, metadata_server_cfg, create_metadata_server(metadata_server_cfg), self._sync_server_clock, disagg_node_id=config.node_id)
         self._metadata_server = create_metadata_server(metadata_server_cfg)
         self._perf_metrics_collector = DisaggPerfMetricsCollector(config.perf_metrics_max_requests)
@@ -157,6 +158,35 @@ def register_routes(self):
         if self._disagg_cluster_storage and isinstance(self._disagg_cluster_storage, HttpClusterStorageServer):
             self._disagg_cluster_storage.add_routes(self.app)

+    @staticmethod
+    def _extract_conversation_id(req: UCompletionRequest, raw_req: Request):
+        """Populate conversation_id from the X-Correlation-ID header.
+
+        When not already set in the request body, copies the header value
+        into ``disaggregated_params.conversation_id``.
+
+        aiperf sends multi-turn session IDs via the ``X-Correlation-ID``
+        header (see aiperf ``base_transports.build_headers``). We mirror
+        that convention so the ConversationRouter can provide session
+        affinity without requiring clients to set the body field.
+
+        When ``disaggregated_params`` is ``None`` (standard OpenAI
+        requests without disagg fields), a minimal instance is created
+        to carry the conversation_id. The service layer always rebuilds
+        ``disaggregated_params`` in ``_get_ctx_request`` /
+        ``_get_gen_request`` before forwarding to workers.
+ """ + header_conv_id = raw_req.headers.get("x-correlation-id") + if header_conv_id is None: + return + if req.disaggregated_params is None: + req.disaggregated_params = DisaggregatedParams( + request_type="context_only", + conversation_id=header_conv_id, + ) + elif req.disaggregated_params.conversation_id is None: + req.disaggregated_params.conversation_id = header_conv_id + def _wrap_entry_point(self, entry_point: Callable) -> Callable: async def wrapper(req: UCompletionRequest, raw_req: Request) -> Response: try: @@ -165,6 +195,7 @@ async def wrapper(req: UCompletionRequest, raw_req: Request) -> Response: self._perf_metrics_collector.stream_requests.inc() else: self._perf_metrics_collector.nonstream_requests.inc() + self._extract_conversation_id(req, raw_req) hooks = RawRequestResponseHooks(raw_req, self._perf_metrics_collector) response_or_generator = await entry_point(req, hooks) self._perf_metrics_collector.total_responses.inc() diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py index 387d7fa03ec3..e77d9e07e178 100644 --- a/tensorrt_llm/serve/openai_disagg_service.py +++ b/tensorrt_llm/serve/openai_disagg_service.py @@ -144,6 +144,19 @@ async def _send_disagg_request_ctx_first( ) await self._verify_ctx_response(ctx_response) gen_req = self._get_gen_request(request, ctx_response, disagg_request_id) + else: + # Clear synthetic disaggregated_params that may have been + # injected by _extract_conversation_id (e.g. from the + # X-Correlation-ID header). When need_ctx=False the gen + # server handles full generation and must not see a stale + # request_type="context_only". + # _check_gen_only_disagg already sets proper generation_only + # params when applicable, so only clear the synthetic ones. + if ( + gen_req.disaggregated_params is not None + and gen_req.disaggregated_params.request_type == "context_only" + ): + gen_req.disaggregated_params = None if ctx_response is None or self._need_gen(ctx_response): if not gen_server: gen_server, _ = await self._gen_router.get_next_server( @@ -163,6 +176,12 @@ def _need_gen(self, response: UCompletionResponse) -> bool: return False return True + @staticmethod + def _get_conversation_id(request: UCompletionRequest) -> Optional[str]: + if request.disaggregated_params is not None: + return request.disaggregated_params.conversation_id + return None + def _get_ctx_request( self, request: UCompletionRequest, disagg_request_id: Optional[int] ) -> UCompletionRequest: @@ -172,6 +191,7 @@ def _get_ctx_request( request_type="context_only", disagg_request_id=disagg_request_id, schedule_style=self._schedule_style, + conversation_id=self._get_conversation_id(request), ), "stream": False, "stream_options": None, @@ -186,10 +206,12 @@ def _get_gen_request( disagg_request_id: Optional[int], ctx_server_info: Optional[dict] = None, ) -> UCompletionRequest: + conversation_id = self._get_conversation_id(request) if ctx_response: request.disaggregated_params = ctx_response.choices[0].disaggregated_params request.disaggregated_params.request_type = "generation_only" request.disaggregated_params.schedule_style = self._schedule_style + request.disaggregated_params.conversation_id = conversation_id # Replace the string prompt with prompt_tokens_ids if isinstance(request, CompletionRequest): request.prompt = ctx_response.prompt_token_ids @@ -202,6 +224,7 @@ def _get_gen_request( ctx_request_id=disagg_request_id, disagg_request_id=disagg_request_id, schedule_style=self._schedule_style, + conversation_id=conversation_id, ) 
             if ctx_server_info and "server_info" in ctx_server_info:
                 disaggregated_params = ctx_server_info["server_info"].get("disaggregated_params", {})
diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py
index 9d6b1fcad608..8730976898fd 100644
--- a/tensorrt_llm/serve/openai_protocol.py
+++ b/tensorrt_llm/serve/openai_protocol.py
@@ -127,6 +127,7 @@ class DisaggregatedParams(OpenAIBaseModel):
     ctx_dp_rank: Optional[int] = None
     ctx_info_endpoint: Optional[str] = None
     schedule_style: Optional[DisaggScheduleStyle] = None
+    conversation_id: Optional[str] = None


 class ErrorResponse(OpenAIBaseModel):
diff --git a/tensorrt_llm/serve/router.py b/tensorrt_llm/serve/router.py
index 53685dbc0ad3..a1dbb027d693 100644
--- a/tensorrt_llm/serve/router.py
+++ b/tensorrt_llm/serve/router.py
@@ -1,7 +1,8 @@
 import asyncio
-import heapq
 import os
+import time
 from abc import ABC, abstractmethod
+from collections import OrderedDict
 from typing import Awaitable, Callable, Dict, Iterable, List, Optional, Union

 import aiohttp
@@ -144,6 +145,70 @@ def num_active_requests(self):
         return self._num_active_requests


+class LoadBalancingMixin:
+    """Mixin providing common server state and request tracking.
+
+    Subclasses should set ``_server_state_class`` and call
+    ``_init_load_balancing()`` in ``__init__``.
+    """
+
+    _server_state_class: type = ServerState
+
+    def _init_load_balancing(self,
+                             servers: Optional[List[str]],
+                             use_tokens: bool = False):
+        self._use_tokens = use_tokens
+        self._server_state: dict[str, ServerState] = {}
+        self._req_routing_table: dict[int, str] = {}
+        self._rr_counter = 0
+        for server in servers or []:
+            self._server_state[server] = self._create_server_state(server)
+
+    def _create_server_state(self, server: str) -> ServerState:
+        return self._server_state_class(server, self._use_tokens)
+
+    def _get_server_load(self, server: str) -> int:
+        state = self._server_state[server]
+        return state._num_active_tokens if self._use_tokens \
+            else state._num_active_requests
+
+    def _validate_servers_available(self):
+        if not self._servers:
+            if self._metadata_server:
+                raise ValueError(
+                    f"No {self._server_role} servers available in metadata service"
+                )
+            else:
+                raise ValueError(f"No {self._server_role} servers available")
+
+    async def _register_request(self, server: str, request: OpenAIRequest):
+        await self._server_state[server].increment_load(request)
+        self._req_routing_table[id(request)] = server
+
+    async def _unregister_request(self, request: OpenAIRequest,
+                                  **kwargs) -> str:
+        server = self._req_routing_table.pop(id(request))
+        if server in self._server_state:
+            await self._server_state[server].decrement_load(request, **kwargs)
+        return server
+
+    def _select_least_loaded(self,
+                             exclude_server: Optional[str] = None
+                             ) -> Optional[str]:
+        """Pick the server with the lowest load. Round-robin breaks ties."""
+        candidates = [s for s in self._server_state if s != exclude_server]
+        if not candidates:
+            return None
+        loads = {s: self._get_server_load(s) for s in candidates}
+        min_load = min(loads.values())
+        tied = [s for s in candidates if loads[s] == min_load]
+        server = tied[self._rr_counter % len(tied)]
+        self._rr_counter += 1
+        logger.debug(f"LoadBalancingMixin: selected={server}, "
+                     f"loads={loads}, tied={tied}, rr={self._rr_counter - 1}")
+        return server
+
+
 class Router(ABC):

     def __init__(
@@ -168,8 +233,11 @@ def __init__(

     @abstractmethod
     def _on_servers_updated(self, old_servers, new_servers):
-        """Called when the server list changes. Override in subclasses to handle index resets.
+        """Called when the server list changes.
+
+        Override in subclasses to handle index resets. Called with lock already held.
+
         Args:
             old_servers: The previous server list
             new_servers: The new server list
@@ -512,7 +580,7 @@ async def finish_request(self, request: OpenAIRequest):
         pass


-class LoadBalancingRouter(Router):
+class LoadBalancingRouter(LoadBalancingMixin, Router):

     def __init__(self,
                  server_role: ServerRole,
@@ -523,89 +591,33 @@ def __init__(self,
                  **kwargs):
         super().__init__(server_role, servers, metadata_server_cfg,
                          metadata_server, **kwargs)
-        # Load map between servers and their number of tokens processed
-        self._server_state = {}
-        self._server_load_heap = []
-
-        # Routing table to map requests to servers
-        self._req_routing_table = {}
-
-        self._use_tokens = use_tokens
-        self._init_heap()
+        self._init_load_balancing(servers, use_tokens)

     def _on_servers_updated(self, old_servers, new_servers):
-        """Rebuild the heap when the server list changes."""
-        # Keep the state for servers that still exist
-        current_state = {}
+        new_state = {}
         for server in new_servers:
-            if server in self._server_state:
-                # Keep existing state
-                current_state[server] = self._server_state[server]
-            else:
-                # Initialize new server state
-                current_state[server] = ServerState(server, self._use_tokens)
-
-        # Update state and rebuild heap
-        self._server_state = current_state
-        self._server_load_heap = []
-        for server in new_servers:
-            heapq.heappush(self._server_load_heap,
-                           (self._get_server_load(server), server))
-
-    def _init_heap(self):
-        for server in self._servers:
-            self._server_state[server] = ServerState(server, self._use_tokens)
-            heapq.heappush(self._server_load_heap,
-                           (self._get_server_load(server), server))
+            new_state[server] = (self._server_state.get(server)
+                                 or self._create_server_state(server))
+        self._server_state = new_state

     async def get_next_server(
             self,
             request: OpenAIRequest,
             exclude_server: Optional[str] = None) -> tuple[str, dict]:
-        if not self._servers:
-            if self._metadata_server:
-                raise ValueError(
-                    f"No {self._server_role} servers available in metadata service"
-                )
-            else:
-                raise ValueError(f"No {self._server_role} servers available")
+        self._validate_servers_available()

         async with self._lock:
-            if exclude_server:
-                server_load_heap = [(self._get_server_load(server), server)
-                                    for server in self._servers
-                                    if server != exclude_server]
-                heapq.heapify(server_load_heap)
-            else:
-                server_load_heap = self._server_load_heap
-
-            server = heapq.heappop(server_load_heap)[1]
-            await self._server_state[server].increment_load(request)
-            # maintain the member heap
-            if exclude_server:
-                self._server_load_heap = server_load_heap
-                if exclude_server in self._server_state:
-                    heapq.heappush(
-                        self._server_load_heap,
-                        (self._get_server_load(exclude_server), exclude_server))
-            heapq.heappush(self._server_load_heap,
-                           (self._get_server_load(server), server))
-
-            self._req_routing_table[id(request)] = server
+            server = self._select_least_loaded(exclude_server)
+            if server is None:
+                raise ValueError(
+                    f"No available servers after excluding {exclude_server}")
+            await self._register_request(server, request)

         return server, {"server_info": self._server_info.get(server, {})}

-    def _get_server_load(self, server):
-        return self._server_state[server]._num_active_tokens if self._use_tokens \
-            else self._server_state[server]._num_active_requests
-
     async def finish_request(self, request: OpenAIRequest):
         async with self._lock:
-            server = self._req_routing_table[id(request)]
-            await self._server_state[server].decrement_load(request)
-            heapq.heappush(self._server_load_heap,
-                           (self._get_server_load(server), server))
-            del self._req_routing_table[id(request)]
+            await self._unregister_request(request)


 def block_key_hasher(token_ids: list[int],
@@ -615,45 +627,24 @@ def block_key_hasher(token_ids: list[int],
                      0 if parent_hash is None else parent_hash)


-class KvCacheAwareRouter(Router):
-
-    def __init__(self,
-                 server_role: ServerRole = None,
-                 servers: list[str] = None,
-                 metadata_server_cfg: MetadataServerConfig = None,
-                 metadata_server: JsonDictionary = None,
-                 use_tokens: bool = False,
-                 max_batch_size: int = 64,
-                 tokens_per_block: int = 32,
-                 **kwargs):
-        super().__init__(server_role, servers, metadata_server_cfg,
-                         metadata_server, **kwargs)
-        self._lock = asyncio.Lock()
-        self._use_tokens = use_tokens
+class BlockHashMixin:
+    """Shared tokenization and block-hash computation.

-        # Load map between servers and their number of tokens processed
-        self._server_state: dict[str, KvCacheAwareServerState] = {
-            server: KvCacheAwareServerState(server, use_tokens)
-            for server in servers or []
-        }
-
-        # Routing table to map requests to servers
-        self._req_routing_table: dict[int, OpenAIRequest] = {}
+    Used by routers that need KV-cache-aware prefix matching.
+    """

-        self._tokenizers = {}
-        # TODO: use max_num_tokens? per server?
-        self._max_batch_size = max_batch_size
+    def _init_block_hashing(self, tokens_per_block: int = 32):
         env_tokens_per_block = os.environ.get(
             "TRTLLM_KVCACHE_AWARE_ROUTER_HASH_TOKENS_PER_BLOCK")
         if env_tokens_per_block is not None:
             tokens_per_block = int(env_tokens_per_block)
         self._tokens_per_block = tokens_per_block
-        logger.info(
-            f"KvCacheAwareRouter: tokens_per_block={self._tokens_per_block}")
+        self._tokenizers: dict = {}

     def _get_tokenizer(self, model: str):
         if model not in self._tokenizers:
-            self._tokenizers[model] = AutoTokenizer.from_pretrained(model)
+            self._tokenizers[model] = AutoTokenizer.from_pretrained(
+                model, trust_remote_code=True)
         return self._tokenizers[model]

     def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
@@ -689,29 +680,69 @@ def _tokenize(self, request: OpenAIRequest) -> list[list[int]]:
         token_lists = [tokenizer(prompt)["input_ids"] for prompt in prompts]
         # Replace string prompts with token IDs so the worker server
         # skips re-tokenization
-        request.prompt = token_lists if len(token_lists) > 1 else token_lists[0]
+        request.prompt = (token_lists
+                          if len(token_lists) > 1 else token_lists[0])
         return token_lists

-    async def get_next_server(
-            self,
-            request: OpenAIRequest,
-            exclude_server: Optional[str] = None) -> tuple[str, dict]:
-        async with self._lock:
-            servers = list([
-                server for server in self._server_state.keys()
-                if server != exclude_server
-            ])
-            token_lists = self._tokenize(request)
+    def _compute_block_hashes(self,
+                              token_lists: list[list[int]]) -> list[list[int]]:
         block_hashes: list[list[int]] = []
         for token_list in token_lists:
             hash_list = []
-            # in KvCacheManager, the last token is not included in the block key
+            # in KvCacheManager, the last token is not included in the
+            # block key
             for t in range(0, len(token_list) - 1, self._tokens_per_block):
                 t_end = min(t + self._tokens_per_block, len(token_list) - 1)
                 hash_list.append(
                     block_key_hasher(token_list[t:t_end],
                                      None if t == 0 else hash_list[-1]))
             block_hashes.append(hash_list)
+        return block_hashes
+
+    @staticmethod
+    def _text_to_int_sequences(texts: list[str]) -> list[list[int]]:
+        """Convert text
strings to lists of unicode code points. + + Usable as input to ``_compute_block_hashes``. + """ + return [[ord(c) for c in text] for text in texts] + + +class KvCacheAwareRouter(BlockHashMixin, LoadBalancingMixin, Router): + + _server_state_class = KvCacheAwareServerState + + def __init__(self, + server_role: ServerRole = None, + servers: list[str] = None, + metadata_server_cfg: MetadataServerConfig = None, + metadata_server: JsonDictionary = None, + use_tokens: bool = False, + max_batch_size: int = 64, + tokens_per_block: int = 32, + **kwargs): + super().__init__(server_role, servers, metadata_server_cfg, + metadata_server, **kwargs) + self._init_block_hashing(tokens_per_block) + self._init_load_balancing(servers, use_tokens) + # TODO: use max_num_tokens? per server? + self._max_batch_size = max_batch_size + + def _create_server_state(self, server): + return KvCacheAwareServerState(server, self._use_tokens, + self._tokens_per_block) + + async def get_next_server( + self, + request: OpenAIRequest, + exclude_server: Optional[str] = None) -> tuple[str, dict]: + async with self._lock: + servers = [ + server for server in self._server_state.keys() + if server != exclude_server + ] + token_lists = self._tokenize(request) + block_hashes = self._compute_block_hashes(token_lists) padded_tokens = sum( len(hash_list) for hash_list in block_hashes) * self._tokens_per_block @@ -731,10 +762,13 @@ async def get_next_server( score = matches[-1] / padded_tokens - workloads[ i] / self._max_batch_size scores.append(score) - server = servers[scores.index(max(scores))] + max_score = max(scores) + tied = [i for i, s in enumerate(scores) if s == max_score] + winner = tied[self._rr_counter % len(tied)] + self._rr_counter += 1 + server = servers[winner] async with self._lock: - await self._server_state[server].increment_load(request) - self._req_routing_table[id(request)] = server + await self._register_request(server, request) return server, { "block_hashes": block_hashes, # list[list[int]] "token_lists": token_lists, # list[list[int]] @@ -746,18 +780,469 @@ async def finish_request(self, request: OpenAIRequest, session: Optional[aiohttp.ClientSession] = None): async with self._lock: - server = self._req_routing_table[id(request)] - del self._req_routing_table[id(request)] - if server in self._server_state: - await self._server_state[server].decrement_load(request, - session=session) + await self._unregister_request(request, session=session) def _on_servers_updated(self, old_servers, new_servers): - for new_server in new_servers: - self._server_state[new_server] = KvCacheAwareServerState( - new_server, self._use_tokens) - for old_server in old_servers: - self._server_state.pop(old_server, None) + new_state = {} + for server in new_servers: + new_state[server] = (self._server_state.get(server) + or self._create_server_state(server)) + self._server_state = new_state + + +class _BlockHashTrie: + """Prefix tree mapping block-hash sequences to session IDs. + + Each session ID is stored at every node along its hash path so that + partial prefix matches are discovered in O(L) time (L = query length). 
+ """ + + class _Node: + __slots__ = ('children', 'session_ids') + + def __init__(self): + self.children: dict[int, '_BlockHashTrie._Node'] = {} + self.session_ids: set[str] = set() + + def __init__(self): + self._root = self._Node() + + def insert(self, session_id: str, block_hashes: list[int]): + """Register *session_id* at every node along *block_hashes*.""" + node = self._root + for h in block_hashes: + if h not in node.children: + node.children[h] = self._Node() + node = node.children[h] + node.session_ids.add(session_id) + + def remove(self, session_id: str, block_hashes: list[int]): + """Remove *session_id* from its hash path and prune empty nodes.""" + node = self._root + path = [] # list of (parent_node, hash_key) + for h in block_hashes: + if h not in node.children: + break + path.append((node, h)) + node = node.children[h] + node.session_ids.discard(session_id) + # Prune empty leaf nodes bottom-up + for parent, key in reversed(path): + child = parent.children[key] + if not child.session_ids and not child.children: + del parent.children[key] + else: + break + + def find_longest_prefix_match( + self, + block_hashes: list[int], + valid_fn: Optional[Callable[[str], bool]] = None, + ) -> tuple[Optional[str], int]: + """Return ``(session_id, match_depth)`` for the deepest valid match. + + Returns ``(None, 0)`` when no valid session matches. + """ + node = self._root + best_id: Optional[str] = None + best_depth = 0 + for depth, h in enumerate(block_hashes, 1): + if h not in node.children: + break + node = node.children[h] + for sid in node.session_ids: + if valid_fn is None or valid_fn(sid): + best_id = sid + best_depth = depth + break + return best_id, best_depth + + +class ConversationRouter(BlockHashMixin, LoadBalancingMixin, Router): + """Router that provides session affinity for multi-turn conversations. + + Routing priority: + 1. Explicit ``conversation_id`` in ``disaggregated_params`` — sticky + routing to the previously assigned server. + 2. Implicit block-hash prefix matching — find the session whose + stored block hashes share the longest prefix with the new request. + If the match ratio exceeds ``match_threshold`` the request is + treated as a continuation. + 3. Fallback — least-loaded server (load-balancing). + + Args: + use_token_ids: When ``True``, tokenize text requests with a + real tokenizer (same hashing as ``KvCacheAwareRouter``). + When ``False`` (default), convert raw text to unicode + code-point sequences for hashing. Pre-existing token IDs + in the request are always used regardless of this flag. + hash_skip_count: Number of leading tokens or code-points to + skip before computing block hashes. Set this to the + approximate length of a shared system prompt so that every + request does not trivially prefix-match on the common + preamble. 
+ """ + + CHAR_PER_TOKEN = 5 # approximately 4 characters per token + 1 space + + def __init__(self, + server_role: ServerRole, + servers: List[str] = None, + metadata_server_cfg: MetadataServerConfig = None, + metadata_server: JsonDictionary = None, + match_threshold: float = 0.75, + tokens_per_block: int = 128, + use_token_ids: bool = False, + hash_skip_count: int = 0, + max_sessions: int = 100000, + **kwargs): + super().__init__(server_role, servers, metadata_server_cfg, + metadata_server, **kwargs) + self._init_load_balancing(servers) + self._init_block_hashing(tokens_per_block) + + self._match_threshold = match_threshold + self._use_token_ids = use_token_ids + self._hash_skip_count = hash_skip_count + self._max_sessions = max_sessions + self._disagg_node_id = kwargs.get("disagg_node_id", 0) + + # conversation_id -> (server, block_hashes) LRU-ordered + self._session_table: OrderedDict[str, + tuple[str, + list[int]]] = (OrderedDict()) + # Prefix tree for O(L) block-hash matching + self._hash_trie = _BlockHashTrie() + # server -> set of conversation_ids (reverse index) + self._server_sessions: dict[str, set[str]] = { + s: set() + for s in (servers or []) + } + self._implicit_id_counter = 0 + + # In-flight content-load tracking: estimated tokens currently + # being processed on each server. Incremented on assignment, + # decremented on finish. When loads are equal, round-robin + # breaks ties to ensure balanced assignment. + self._server_content_load: dict[str, int] = { + s: 0 + for s in (servers or []) + } + # id(request) -> (server, weight, monotonic_timestamp) + self._req_content_entry: dict[int, tuple[str, int, float]] = {} + + # ── content-based load tracking ── + + def _estimate_content_weight( + self, + request: OpenAIRequest, + block_hashes: Optional[list[int]] = None) -> int: + """Estimate request weight in tokens without tokenization. + + When *block_hashes* are available (IMPLICIT / FALLBACK paths), + uses ``len(block_hashes) * tokens_per_block``. Otherwise + estimates from text character length. + """ + if block_hashes is not None: + return len(block_hashes) * self._tokens_per_block + text = self._extract_text(request) + return max(len(text) // self.CHAR_PER_TOKEN, 1) + + def _add_content_load(self, server: str, request: OpenAIRequest, + weight: int): + self._server_content_load[server] = ( + self._server_content_load.get(server, 0) + weight) + self._req_content_entry[id(request)] = (server, weight, + time.monotonic()) + + def _remove_content_load(self, server: str, request: OpenAIRequest): + entry = self._req_content_entry.pop(id(request), None) + if entry is not None: + _, weight, _ = entry + self._server_content_load[server] = max( + self._server_content_load.get(server, 0) - weight, 0) + + def _get_content_load(self, server: str) -> int: + return self._server_content_load.get(server, 0) + + def _get_server_load(self, server: str) -> int: + """Use content weight so ``_select_least_loaded`` balances by + estimated tokens rather than request count. + """ + return self._get_content_load(server) + + def _on_servers_updated(self, old_servers, new_servers): + """Rebuild reverse index and evict stale sessions. + + Also syncs ``LoadBalancingMixin._server_state`` so that + ``_select_least_loaded`` stays consistent with the live server list. + """ + # Sync load-balancer state (same pattern as RoundRobinRouter). 
+ new_state = {} + for server in new_servers: + new_state[server] = (self._server_state.get(server) + or self._create_server_state(server)) + self._server_state = new_state + + new_server_sessions: dict[str, set[str]] = {} + for server in new_servers: + new_server_sessions[server] = self._server_sessions.get( + server, set()) + if server not in self._server_content_load: + self._server_content_load[server] = 0 + + # Evict sessions pointing to removed servers + removed_servers = set(old_servers) - set(new_servers) + for removed in removed_servers: + for conv_id in list(self._server_sessions.get(removed, ())): + entry = self._session_table.pop(conv_id, None) + if entry is not None: + self._hash_trie.remove(conv_id, entry[1]) + self._server_content_load.pop(removed, None) + + self._server_sessions = new_server_sessions + + # ── text extraction & block-hash prefix matching ── + + @staticmethod + def _extract_text(request: OpenAIRequest) -> str: + """Return a canonical text representation of the request content.""" + if isinstance(request, ChatCompletionRequest): + parts = [] + for msg in request.messages: + m = msg if isinstance(msg, dict) else dict(msg) + parts.append(f"{m.get('role', '')}:{m.get('content', '')}") + return "\n".join(parts) + + # CompletionRequest + prompt = request.prompt + if isinstance(prompt, str): + return prompt + if isinstance(prompt, list): + if prompt and isinstance(prompt[0], str): + return "\n".join(prompt) + return str(prompt) + return str(prompt) + + @staticmethod + def _try_extract_token_ids( + request: OpenAIRequest) -> Optional[list[list[int]]]: + """Return pre-existing token-ID lists from the request. + + Returns ``None`` when the request does not already carry them. + """ + if isinstance(request, ChatCompletionRequest): + if request.prompt_token_ids is not None: + return [request.prompt_token_ids] + return None + + # CompletionRequest + prompt = request.prompt + if isinstance(prompt, list): + if prompt and isinstance(prompt[0], list): + return prompt + if prompt and isinstance(prompt[0], int): + return [prompt] + return None + + def _request_to_block_hashes(self, request: OpenAIRequest) -> list[int]: + """Compute block hashes for *request*. + + Resolution order: + 1. Pre-existing token IDs in the request → use directly. + 2. ``use_token_ids=True`` → tokenize text via ``_tokenize()``. + 3. Fallback → convert raw text to unicode code-point sequences. + + When ``hash_skip_count > 0`` the first *hash_skip_count* + elements (tokens or code-points) are stripped before hashing, + which is useful for ignoring a shared system prompt that would + otherwise cause every request to prefix-match. + """ + token_ids = self._try_extract_token_ids(request) + if token_ids is not None: + int_sequences = token_ids + skip_count = self._hash_skip_count + elif self._use_token_ids: + int_sequences = self._tokenize(request) + skip_count = self._hash_skip_count + else: + text = self._extract_text(request) + int_sequences = self._text_to_int_sequences([text]) + skip_count = self._hash_skip_count * self.CHAR_PER_TOKEN + + if skip_count > 0: + int_sequences = [seq[skip_count:] for seq in int_sequences] + + return self._compute_block_hashes(int_sequences)[0] + + def _find_matching_session(self, block_hashes: list[int], + exclude_server: Optional[str]) -> Optional[str]: + """Find the session with the longest matching block-hash prefix. + + Uses ``_hash_trie`` for O(L) lookup. Returns ``None`` when no + session meets the match-ratio threshold. 
+ """ + if not block_hashes: + return None + + def _valid(conv_id: str) -> bool: + entry = self._session_table.get(conv_id) + if entry is None: + return False + server = entry[0] + return (server in self._server_state and server != exclude_server) + + best_conv_id, best_depth = self._hash_trie.find_longest_prefix_match( + block_hashes, _valid) + + if best_conv_id is None: + return None + ratio = best_depth / len(block_hashes) + if ratio >= self._match_threshold: + best_server = self._session_table[best_conv_id][0] + logger.debug( + f"ConversationRouter: implicit match conv_id={best_conv_id}, " + f"server={best_server}, match_ratio={ratio:.3f} " + f"({best_depth}/{len(block_hashes)} blocks)") + return best_conv_id + return None + + # ── routing helpers ── + + def _get_conversation_id(self, request: OpenAIRequest) -> Optional[str]: + if request.disaggregated_params is not None: + return request.disaggregated_params.conversation_id + return None + + def _generate_implicit_id(self) -> str: + self._implicit_id_counter += 1 + return f"conv_id:{self._disagg_node_id}_{self._implicit_id_counter}" + + def _update_session(self, conv_id: str, server: str, + block_hashes: list[int]): + old = self._session_table.get(conv_id) + if old is not None: + old_server, old_hashes = old + if old_server in self._server_sessions: + self._server_sessions[old_server].discard(conv_id) + self._hash_trie.remove(conv_id, old_hashes) + self._session_table[conv_id] = (server, block_hashes) + self._session_table.move_to_end(conv_id) + self._hash_trie.insert(conv_id, block_hashes) + if server in self._server_sessions: + self._server_sessions[server].add(conv_id) + # LRU eviction when over capacity + while len(self._session_table) > self._max_sessions: + self._evict_oldest_session() + + def _evict_oldest_session(self): + """Remove the least-recently-used session from all indices.""" + conv_id, (server, hashes) = self._session_table.popitem(last=False) + self._hash_trie.remove(conv_id, hashes) + if server in self._server_sessions: + self._server_sessions[server].discard(conv_id) + + # ── public interface ── + + async def get_next_server( + self, + request: OpenAIRequest, + exclude_server: Optional[str] = None) -> tuple[str, dict]: + self._validate_servers_available() + + # Pre-compute outside the lock (tokenization + hashing) + conv_id = self._get_conversation_id(request) + block_hashes = self._request_to_block_hashes(request) + weight = self._estimate_content_weight(request, block_hashes) + + async with self._lock: + + # 1. Explicit conversation_id — sticky routing. + # Always honour session affinity when the server is alive + # and not explicitly excluded. No overload gate — the + # server itself provides backpressure. 
+ if conv_id and conv_id in self._session_table: + sticky_server, _ = self._session_table[conv_id] + if sticky_server not in self._server_state: + logger.debug( + f"ConversationRouter: STICKY MISS conv_id={conv_id} " + f"-> server={sticky_server} NOT in server_state, " + f"falling through to FALLBACK") + elif sticky_server == exclude_server: + logger.debug( + f"ConversationRouter: STICKY MISS conv_id={conv_id} " + f"-> server={sticky_server} is exclude_server") + else: + self._update_session(conv_id, sticky_server, block_hashes) + await self._register_request(sticky_server, request) + self._add_content_load(sticky_server, request, weight) + loads = { + s: self._get_content_load(s) + for s in self._servers + } + logger.debug( + f"ConversationRouter: STICKY conv_id={conv_id} " + f"-> server={sticky_server}, " + f"content_loads={loads}, weight={weight}") + return sticky_server, { + "server_info": self._server_info.get(sticky_server, {}) + } + elif conv_id: + logger.debug(f"ConversationRouter: NEW conv_id={conv_id} " + f"not in session_table " + f"(size={len(self._session_table)})") + + # 2. Implicit block-hash prefix matching. + # Always honour match when the server is alive. + matched_id = None + if not conv_id: + matched_id = self._find_matching_session( + block_hashes, exclude_server) + if matched_id is not None: + sticky_server, _ = self._session_table[matched_id] + self._update_session(matched_id, sticky_server, + block_hashes) + await self._register_request(sticky_server, request) + self._add_content_load(sticky_server, request, weight) + loads = { + s: self._get_content_load(s) + for s in self._servers + } + logger.debug( + f"ConversationRouter: IMPLICIT match " + f"conv_id={matched_id} -> server={sticky_server}, " + f"content_loads={loads}, weight={weight}") + return sticky_server, { + "server_info": self._server_info.get(sticky_server, {}) + } + + # 3. Fallback — least-loaded server for new sessions or + # sessions whose sticky server is unavailable. + server = self._select_least_loaded(exclude_server) + if server is None: + raise ValueError( + f"No available servers after excluding {exclude_server}") + await self._register_request(server, request) + self._add_content_load(server, request, weight) + + # Store session mapping. + if not conv_id: + conv_id = self._generate_implicit_id() + self._update_session(conv_id, server, block_hashes) + loads = {s: self._get_content_load(s) for s in self._servers} + logger.debug( + f"ConversationRouter: FALLBACK conv_id={conv_id} " + f"-> server={server}, content_loads={loads}, weight={weight}") + + return server, {"server_info": self._server_info.get(server, {})} + + async def finish_request(self, request: OpenAIRequest): + async with self._lock: + server = await self._unregister_request(request) + self._remove_content_load(server, request) + loads = {s: self._get_content_load(s) for s in self._servers} + logger.debug(f"ConversationRouter: FINISH server={server}, " + f"content_loads={loads}") def create_router( @@ -765,7 +1250,8 @@ def create_router( servers: Optional[List[str]], metadata_server_cfg: Optional[MetadataServerConfig] = None, metadata_server: Optional[JsonDictionary] = None, - server_preparation_func: Optional[Callable[[str], Awaitable[None]]] = None + server_preparation_func: Optional[Callable[[str], Awaitable[None]]] = None, + disagg_node_id: int = 0, ) -> Router: """ Factory function to create different types of router instances. 
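A minimal sketch of the intended call pattern for any of these routers (the forwarding step is elided; the key invariant is that every `get_next_server` is paired with exactly one `finish_request`, so the per-server request count and content weight drain back to zero even on errors):

```python
async def handle(router, request):
    server, extra_info = await router.get_next_server(request)
    try:
        ...  # forward `request` to `server`, stream back the response
    finally:
        # Always unregister, even on failure, so load tracking and the
        # routing table stay consistent.
        await router.finish_request(request)
```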
@@ -787,6 +1273,7 @@ def create_router( "round_robin": RoundRobinRouter, "load_balancing": LoadBalancingRouter, "kv_cache_aware": KvCacheAwareRouter, + "conversation": ConversationRouter, } router_type = router_config.type if router_config else "round_robin" router_class = router_map.get(router_type.lower()) @@ -795,6 +1282,7 @@ def create_router( raise ValueError(f"Unsupported router type: {router_type}. " f"Supported types are: {list(router_map.keys())}") extra_args = router_config.args if router_config else {} + extra_args["disagg_node_id"] = disagg_node_id return router_class(router_config.server_role if router_config else None, servers, diff --git a/tensorrt_llm/serve/visual_gen_utils.py b/tensorrt_llm/serve/visual_gen_utils.py index b4d16cc37239..27aacf33766a 100644 --- a/tensorrt_llm/serve/visual_gen_utils.py +++ b/tensorrt_llm/serve/visual_gen_utils.py @@ -17,57 +17,61 @@ def parse_visual_gen_params( id: str, media_storage_path: Optional[str] = None, ) -> VisualGenParams: - params = VisualGenParams() - params.prompt = request.prompt - if request.negative_prompt is not None: - params.negative_prompt = request.negative_prompt + kwargs: Dict[str, Any] = {} + extra: Dict[str, Any] = {} + + kwargs["negative_prompt"] = request.negative_prompt if request.size is not None and request.size != "auto": - params.width, params.height = map(int, request.size.split("x")) + kwargs["width"], kwargs["height"] = map(int, request.size.split("x")) if request.guidance_scale is not None: - params.guidance_scale = request.guidance_scale + kwargs["guidance_scale"] = request.guidance_scale if request.guidance_rescale is not None: - params.guidance_rescale = request.guidance_rescale + extra["guidance_rescale"] = request.guidance_rescale - if isinstance(request, ImageGenerationRequest) or isinstance(request, ImageEditRequest): + if isinstance(request, (ImageGenerationRequest, ImageEditRequest)): if request.num_inference_steps is not None: - params.num_inference_steps = request.num_inference_steps + kwargs["num_inference_steps"] = request.num_inference_steps elif isinstance(request, ImageGenerationRequest) and request.quality == "hd": - params.num_inference_steps = 30 + kwargs["num_inference_steps"] = 30 if request.n is not None: - params.num_images_per_prompt = request.n + kwargs["num_images_per_prompt"] = request.n if isinstance(request, ImageEditRequest): if request.image is not None: if isinstance(request.image, list): - params.image = [base64.b64decode(image) for image in request.image] + kwargs["image"] = [base64.b64decode(image) for image in request.image] else: - params.image = [base64.b64decode(request.image)] + kwargs["image"] = [base64.b64decode(request.image)] if request.mask is not None: if isinstance(request.mask, list): - params.mask = [base64.b64decode(mask) for mask in request.mask] + kwargs["mask"] = [base64.b64decode(mask) for mask in request.mask] else: - params.mask = base64.b64decode(request.mask) + kwargs["mask"] = base64.b64decode(request.mask) elif isinstance(request, VideoGenerationRequest): if request.num_inference_steps is not None: - params.num_inference_steps = request.num_inference_steps + kwargs["num_inference_steps"] = request.num_inference_steps if request.input_reference is not None: if media_storage_path is None: raise ValueError("media_storage_path is required when input_reference is provided") - params.input_reference = os.path.join(media_storage_path, f"{id}_reference.png") + ref_path = os.path.join(media_storage_path, f"{id}_reference.png") if 
isinstance(request.input_reference, str): - with open(params.input_reference, "wb") as f: + with open(ref_path, "wb") as f: f.write(base64.b64decode(request.input_reference)) else: - with open(params.input_reference, "wb") as f: + with open(ref_path, "wb") as f: shutil.copyfileobj(request.input_reference.file, f) + kwargs["image"] = ref_path - params.frame_rate = request.fps - params.num_frames = int(request.seconds * request.fps) + kwargs["frame_rate"] = request.fps + kwargs["num_frames"] = int(request.seconds * request.fps) if request.seed is not None: - params.seed = int(request.seed) + kwargs["seed"] = int(request.seed) + + if extra: + kwargs["extra_params"] = extra - return params + return VisualGenParams(**kwargs) class AsyncDictStore: diff --git a/tensorrt_llm/visual_gen/__init__.py b/tensorrt_llm/visual_gen/__init__.py index 85421a1abb0c..8c346734b137 100644 --- a/tensorrt_llm/visual_gen/__init__.py +++ b/tensorrt_llm/visual_gen/__init__.py @@ -13,12 +13,22 @@ # See the License for the specific language governing permissions and # limitations under the License. from .args import VisualGenArgs -from .visual_gen import MediaOutput, VisualGen, VisualGenError, VisualGenParams, VisualGenResult +from .params import VisualGenParams +from .visual_gen import ( + ExtraParamSchema, + MediaOutput, + VisualGen, + VisualGenError, + VisualGenParamsError, + VisualGenResult, +) __all__ = [ "VisualGen", "VisualGenArgs", + "ExtraParamSchema", "VisualGenError", + "VisualGenParamsError", "VisualGenParams", "VisualGenResult", "MediaOutput", diff --git a/tensorrt_llm/visual_gen/params.py b/tensorrt_llm/visual_gen/params.py new file mode 100644 index 000000000000..87754a56c956 --- /dev/null +++ b/tensorrt_llm/visual_gen/params.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, List, Optional, Union + +from pydantic import Field + +from tensorrt_llm.llmapi.utils import StrictBaseModel, set_api_status + + +@set_api_status("prototype") +class VisualGenParams(StrictBaseModel): + """Parameters for visual generation. + + Fields default to ``None``, meaning "use the model's default". + Per-model defaults are declared by each pipeline via + ``DEFAULT_GENERATION_PARAMS`` and merged automatically before + inference. + + Model-specific parameters (e.g. LTX-2's ``stg_scale``, Wan's + ``guidance_scale_2``) should be passed via ``extra_params``. + Use ``VisualGen.extra_param_specs`` to discover valid keys + for the loaded pipeline. + """ + + # Core — None means "use model default" + height: Optional[int] = Field(default=None, description="Output height in pixels.") + width: Optional[int] = Field(default=None, description="Output width in pixels.") + num_inference_steps: Optional[int] = Field( + default=None, description="Number of denoising steps." 
+ ) + guidance_scale: Optional[float] = Field( + default=None, description="Classifier-free guidance scale." + ) + max_sequence_length: Optional[int] = Field( + default=None, description="Max tokens for text encoding." + ) + seed: int = Field(default=42, description="Random seed for reproducibility.") + + # Video + num_frames: Optional[int] = Field( + default=None, description="Number of frames. None = model default." + ) + frame_rate: Optional[float] = Field(default=None, description="Video frame rate in fps.") + + # Conditioning inputs + negative_prompt: Optional[str] = Field(default=None, description="Negative prompt for CFG.") + image: Optional[Union[str, bytes, List[Union[str, bytes]]]] = Field( + default=None, description="Reference image(s) for I2V/I2I." + ) + mask: Optional[Union[str, bytes, List[bytes]]] = Field( + default=None, description="Inpainting mask path or raw bytes." + ) + image_cond_strength: Optional[float] = Field( + default=None, description="Image conditioning strength." + ) + + # Per-prompt multiplier + num_images_per_prompt: int = Field(default=1, description="Number of images per prompt.") + + # Model-specific overflow + extra_params: Optional[Dict[str, Any]] = Field( + default=None, + description="Model-specific parameters. Use VisualGen.extra_param_specs " + "to discover valid keys for the loaded pipeline.", + ) diff --git a/tensorrt_llm/visual_gen/visual_gen.py b/tensorrt_llm/visual_gen/visual_gen.py index d675d5b537e7..80d5e5974d0d 100644 --- a/tensorrt_llm/visual_gen/visual_gen.py +++ b/tensorrt_llm/visual_gen/visual_gen.py @@ -22,7 +22,6 @@ import time import traceback import weakref -from dataclasses import dataclass from pathlib import Path from typing import Any, Dict, List, Literal, Optional, Union @@ -32,9 +31,19 @@ from tensorrt_llm._torch.visual_gen import DiffusionRequest, DiffusionResponse from tensorrt_llm._torch.visual_gen.executor import run_diffusion_worker from tensorrt_llm._torch.visual_gen.output import MediaOutput +from tensorrt_llm._torch.visual_gen.pipeline import ExtraParamSchema from tensorrt_llm.visual_gen.args import VisualGenArgs - -__all__ = ["VisualGen", "VisualGenParams", "MediaOutput", "VisualGenError", "VisualGenResult"] +from tensorrt_llm.visual_gen.params import VisualGenParams + +__all__ = [ + "VisualGen", + "VisualGenParams", + "ExtraParamSchema", + "MediaOutput", + "VisualGenError", + "VisualGenParamsError", + "VisualGenResult", +] from tensorrt_llm.executor.ipc import ZeroMqQueue from tensorrt_llm.inputs.data import VisualGenInputs from tensorrt_llm.llmapi.utils import set_api_status @@ -65,10 +74,22 @@ def get_ip_address() -> str: s.close() +@set_api_status("prototype") class VisualGenError(RuntimeError): """Base exception for all VisualGen operations.""" +@set_api_status("prototype") +class VisualGenParamsError(ValueError): + """Raised when request parameters fail validation. + + This covers unknown parameter keys, unsupported universal fields + for the loaded pipeline, type mismatches, and out-of-range values. + Caught by the executor so it returns an error response rather than + crashing the server. + """ + + class DiffusionRemoteClient: """Client proxy for remote DiffusionExecutor in worker processes.""" @@ -116,6 +137,10 @@ def __init__( # Wait for the background thread to initialize the event loop self.event_loop_ready.wait() + # Pipeline metadata — populated by _wait_ready from the READY signal. 
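As an illustration of the error contract, here is a hypothetical caller-side sketch. It assumes a constructed `visual_gen = VisualGen(model=...)` instance; `stg_scale` is LTX-2's knob from the removed dataclass, and whether the loaded pipeline actually declares it is pipeline-dependent, which is exactly what the validation checks.

```python
from tensorrt_llm.visual_gen import VisualGenParams, VisualGenParamsError

params = VisualGenParams(extra_params={"stg_scale": 0.7})

try:
    result = visual_gen.generate(inputs="a red fox in snow", params=params)
except VisualGenParamsError as exc:
    # Unknown key, wrong type, or out-of-range value: the executor
    # converts this into an error response instead of crashing.
    print(f"rejected params: {exc}")
```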
+ self.default_generation_params: Dict = {} + self.extra_param_specs: Dict = {} + # Launch workers (VisualGenArgs is pickled via mp.Process spawn context) n_workers = self.n_workers logger.info(f"DiffusionClient: Launching {n_workers} workers") @@ -368,7 +393,14 @@ async def _wait_ready_async(self): while True: async with self.lock: if -1 in self.completed_responses: - self.completed_responses.pop(-1) + ready_resp = self.completed_responses.pop(-1) + # Extract pipeline metadata from the READY payload. + payload = ready_resp.output + if isinstance(payload, dict): + self.default_generation_params = payload.get( + "default_generation_params", {} + ) + self.extra_param_specs = payload.get("extra_param_specs", {}) elapsed = time.time() - start_time logger.info(f"DiffusionClient: Workers ready ({elapsed:.1f}s)") return @@ -389,6 +421,7 @@ async def _wait_ready_async(self): self.response_event.clear() +@set_api_status("prototype") class VisualGenResult: """Future-like object for async generation.""" @@ -447,69 +480,6 @@ def cancel(self): raise NotImplementedError("Cancel request (not yet implemented).") -@dataclass -@set_api_status("prototype") -class VisualGenParams: - """Parameters for visual generation. - - Attributes: - height: Output height in pixels - width: Output width in pixels - num_inference_steps: Number of denoising steps - guidance_scale: Classifier-free guidance scale - max_sequence_length: Maximum sequence length for text encoding - seed: Random seed for reproducibility - - # Video-specific parameters - num_frames: Number of video frames to generate - frame_rate: Frame rate for video output in fps - - # Image-specific parameters - num_images_per_prompt: Number of images to generate per prompt (for image models) - - # Advanced parameters - guidance_rescale: Guidance rescale factor (for some models) - output_type: Output type ("pt" for PyTorch tensors, "pil" for PIL images) - """ - - height: int = 720 - width: int = 1280 - num_inference_steps: int = 50 - guidance_scale: float = 5.0 - max_sequence_length: int = 512 - seed: int = 42 - - # Video-specific parameters - num_frames: int = 81 - frame_rate: float = 24.0 - input_reference: Optional[str] = None - image_cond_strength: float = 1.0 - - # Image-specific parameters - num_images_per_prompt: int = 1 - - # Image edit parameters - image: Optional[List[str]] = None - mask: Optional[str] = None - - # Advanced parameters - guidance_rescale: float = 0.0 - output_type: str = "pt" - - # LTX-2 multi-modal guidance (STG / modality guidance) - stg_scale: float = 0.0 - stg_blocks: Optional[List[int]] = None - modality_scale: float = 1.0 - rescale_scale: float = 0.0 - guidance_skip_step: int = 0 - enhance_prompt: bool = False - - # Wan-specific parameters - guidance_scale_2: Optional[float] = None - boundary_ratio: Optional[float] = None - last_image: Optional[str] = None - - class VisualGen: """High-level API for visual generation.""" @@ -529,6 +499,42 @@ def __init__( atexit.register(VisualGen._atexit_shutdown, weakref.ref(self)) + @property + def extra_param_specs(self) -> Dict[str, "ExtraParamSchema"]: + """Returns extra param specs for the loaded pipeline. + + Use this to discover types, ranges, and descriptions of + model-specific parameters passed via ``extra_params``. + """ + return self.executor.extra_param_specs + + @property + def default_params(self) -> "VisualGenParams": + """Returns a ``VisualGenParams`` with all defaults resolved for the loaded pipeline. + + Universal fields (height, width, etc.) 
are filled from the + pipeline's defaults. All declared ``extra_params`` keys are + included with their defaults (``None`` for params without one). + + Use this to inspect what the model will use, then modify and + pass to ``generate()``:: + + params = visual_gen.default_params + params.extra_params["stg_scale"] = 0.5 + params.height = 1024 + output = visual_gen.generate(inputs="a cat", params=params) + """ + kwargs = dict(self.executor.default_generation_params) + extra = {} + + for key, spec in self.executor.extra_param_specs.items(): + extra[key] = spec.default + + if extra: + kwargs["extra_params"] = extra + + return VisualGenParams(**kwargs) + @set_api_status("prototype") def generate( self, @@ -621,19 +627,9 @@ def generate_async( num_frames=params.num_frames, frame_rate=params.frame_rate, num_images_per_prompt=params.num_images_per_prompt, - guidance_rescale=params.guidance_rescale, - output_type=params.output_type, - stg_scale=params.stg_scale, - stg_blocks=params.stg_blocks, - modality_scale=params.modality_scale, - rescale_scale=params.rescale_scale, - guidance_skip_step=params.guidance_skip_step, - enhance_prompt=params.enhance_prompt, - image=params.input_reference, + image=params.image, image_cond_strength=params.image_cond_strength, - guidance_scale_2=params.guidance_scale_2, - boundary_ratio=params.boundary_ratio, - last_image=params.last_image, + extra_params=params.extra_params, ) self.executor.enqueue_requests([request]) diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index f7cbdd219a14..a334d46a8d01 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -6538,7 +6538,7 @@ def test_nvfp4_4gpu_mtp_ar(self): assert accept_rate > 0.2, \ f"Acceptance rate too low for prompt {i}: {accept_rate:.2f}" - @skip_pre_hopper + @skip_pre_blackwell @pytest.mark.skip_less_device(4) @pytest.mark.skip_less_device_memory(80000) def test_bf16_4gpu_mtp_ar(self): diff --git a/tests/integration/defs/examples/test_visual_gen.py b/tests/integration/defs/examples/test_visual_gen.py index 82fcd0733ee1..77b80d805779 100644 --- a/tests/integration/defs/examples/test_visual_gen.py +++ b/tests/integration/defs/examples/test_visual_gen.py @@ -468,7 +468,7 @@ def _generate_ltx2_two_stage_video(llm_venv, output_subdir, linear_type="default vg_kwargs["parallel"] = {"dit_cfg_size": 2} diffusion_args = VisualGenArgs(**vg_kwargs) - visual_gen = VisualGen(model_path=model_path, diffusion_args=diffusion_args) + visual_gen = VisualGen(model=model_path, args=diffusion_args) try: params = VisualGenParams( diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 43266eba2d58..8031d2e640bc 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -97,6 +97,7 @@ l0_a10: - test_e2e.py::test_trtllm_bench_request_rate_and_concurrency[enable_concurrency-] - test_e2e.py::test_trtllm_bench_invalid_token_pytorch[TinyLlama-1.1B-Chat-v1.0-TinyLlama-1.1B-Chat-v1.0] # visual_gen + - unittest/_torch/visual_gen/test_visual_gen_params.py - unittest/_torch/visual_gen/test_media_storage.py # llmapi - unittest/llmapi/test_llm_utils.py diff --git a/tests/integration/test_lists/test-db/l0_b200.yml b/tests/integration/test_lists/test-db/l0_b200.yml index 08272f76ad59..17a17fb673ad 100644 --- a/tests/integration/test_lists/test-db/l0_b200.yml +++ 
b/tests/integration/test_lists/test-db/l0_b200.yml @@ -168,6 +168,7 @@ l0_b200: - kv_cache/test_kv_cache_iteration_stats.py::TestKvCacheIterationStats::test_field_completeness # ------------- Visual Gen tests --------------- - unittest/_torch/visual_gen/test_visual_gen_args.py + - unittest/_torch/visual_gen/test_visual_gen_params.py - unittest/_torch/visual_gen/test_warmup.py - unittest/_torch/visual_gen/test_teacache.py - unittest/_torch/visual_gen/test_fused_qkv.py diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index bb2344f7e8b6..cdf9fe13dfc1 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -313,7 +313,6 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0- accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_scheduler=True] SKIP (https://nvbugs/6037653) accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_pard[overlap_scheduler=False] SKIP (https://nvbugs/6037653) accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_latency] SKIP (https://nvbugs/6037654) -accuracy/test_llm_api_pytorch.py::TestPhi4::test_auto_dtype SKIP (https://nvbugs/6040098) perf/test_perf.py::test_perf[deepseek_r1_distill_qwen_32b-bench-_autodeploy-float16-input_output_len:1024,1024-reqs:512] SKIP (https://nvbugs/6044213) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-dp4-triton-auto] SKIP (https://nvbugs/6026678) accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[v1_kv_cache-ep4-triton-auto] SKIP (https://nvbugs/6026678) @@ -335,3 +334,20 @@ perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4 perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-v32-fp4_32k4k_con256_ctx1_dep4_gen1_dep32_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/5844149) perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_128k8k_con128_ctx1_pp8_gen1_dep16_eplb0_mtp2_ccb-UCX] SKIP (https://nvbugs/6060119) perf/test_perf_sanity.py::test_e2e[disagg_upload-gen_only-gb200_deepseek-r1-fp4_128k8k_con64_ctx1_pp8_gen1_dep32_eplb0_mtp3_ccb-UCX] SKIP (https://nvbugs/6060119) +accuracy/test_llm_api_pytorch.py::TestKimiK2::test_nvfp4[4gpus] SKIP (https://nvbugs/6069790) +accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_2_model_mtp[2model_trtllm] SKIP (https://nvbugs/5981293) +accuracy/test_llm_api_pytorch.py::TestGLM4_5Air::test_nvfp4_multi_gpus[throughput] SKIP (https://nvbugs/5981293) +disaggregated/test_disaggregated.py::test_disaggregated_trtllm_sampler[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/6069686) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/6071081) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True-v2_kv_cache=True] SKIP (https://nvbugs/6071081) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=True-v2_kv_cache=False] SKIP (https://nvbugs/6071081) 
+accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-enable_chunked_prefill=False-v2_kv_cache=True] SKIP (https://nvbugs/6071081) +accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[google_gemma-3-1b-it-False] SKIP (https://nvbugs/6059036) +accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[meta-llama_Llama-3.1-8B-Instruct-False] SKIP (https://nvbugs/6059036) +accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[mistralai_Codestral-22B-v0.1-False] SKIP (https://nvbugs/6059036) +accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[mistralai_Ministral-8B-Instruct-2410-False] SKIP (https://nvbugs/6059036) +accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[nvidia_Llama-3.1-Nemotron-Nano-8B-v1-False] SKIP (https://nvbugs/6059036) +accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[Qwen_QwQ-32B-False] SKIP (https://nvbugs/6059036) +accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=False-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/6050489) +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=True] SKIP (https://nvbugs/6050489) +accuracy/test_llm_api_autodeploy.py::TestModelRegistryAccuracy::test_autodeploy_from_registry[deepseek-ai_DeepSeek-R1-0528-True] SKIP (https://nvbugs/6070955) diff --git a/tests/microbenchmarks/all_reduce.py b/tests/microbenchmarks/all_reduce.py index ca9d9a7610c3..66a0a19d287d 100644 --- a/tests/microbenchmarks/all_reduce.py +++ b/tests/microbenchmarks/all_reduce.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os from argparse import ArgumentParser from itertools import product @@ -26,10 +27,14 @@ from cuda import cudart import tensorrt_llm as tllm +import tensorrt_llm.bindings.internal.userbuffers as ub from tensorrt_llm import Mapping from tensorrt_llm._torch.autotuner import AutoTuner, autotune +from tensorrt_llm._torch.custom_ops.userbuffers_custom_ops import \ + copy_to_userbuffers from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp, - Distributed) + Distributed, + userbuffers_allreduce_finalize) from tensorrt_llm._torch.modules.rms_norm import RMSNorm from tensorrt_llm._utils import (get_sm_version, local_mpi_rank, local_mpi_size, nvtx_range) @@ -52,6 +57,8 @@ def profile_allreduce( norm=None, scale=None, bias=None, + allreduce_instance=None, + dtype=None, ): allreduce_params = AllReduceParams( @@ -63,7 +70,8 @@ def profile_allreduce( bias=bias, ) - allreduce = AllReduce(mapping=mapping, strategy=strategy) + allreduce = allreduce_instance or AllReduce( + mapping=mapping, strategy=strategy, dtype=dtype) def func(x, loop_num=inner_loop): for _ in range(loop_num): @@ -273,6 +281,365 @@ def allreduce_benchmark( return df +# ── nccl-tests style comprehensive benchmark (--benchmark mode) ────────────── + +_STRATEGY_MAP = { + "NCCL": AllReduceStrategy.NCCL, + "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC, + "UB": AllReduceStrategy.UB, + "ONESHOT": AllReduceStrategy.ONESHOT, + "TWOSHOT": AllReduceStrategy.TWOSHOT, + "AUTO": AllReduceStrategy.AUTO, + "MNNVL": AllReduceStrategy.MNNVL, +} +_UB_STRATEGIES = {AllReduceStrategy.NCCL_SYMMETRIC, AllReduceStrategy.UB} +_FUSION_MAP = { + "NONE": + AllReduceFusionOp.NONE, + "RESIDUAL_RMS_NORM": + AllReduceFusionOp.RESIDUAL_RMS_NORM, + "RESIDUAL_RMS_NORM_QUANT_FP8": + AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8, + "RESIDUAL_RMS_NORM_QUANT_NVFP4": + AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4, +} + + +def _fmt_size(nbytes): + """Format byte count as human-readable string (e.g. 
256B, 4K, 1M, 2G).""" + if nbytes < 1024: + return f"{nbytes}B" + elif nbytes < 1024**2: + v = nbytes / 1024 + return f"{v:.0f}K" if nbytes % 1024 == 0 else f"{v:.1f}K" + elif nbytes < 1024**3: + v = nbytes / 1024**2 + return f"{v:.0f}M" if nbytes % (1024**2) == 0 else f"{v:.2f}M" + else: + v = nbytes / 1024**3 + return f"{v:.0f}G" if nbytes % (1024**3) == 0 else f"{v:.2f}G" + + +def _profile_ub(mapping, + dist, + allreduce, + fusion, + input, + residual, + norm, + scale, + enable_cudagraph=False, + inner_loop=200, + outer_loop=10): + """Profile UB allreduce kernel only (copy_to_ub and finalize are one-shot).""" + allreduce_params = AllReduceParams(fusion_op=fusion, + residual=residual, + norm_weight=norm.weight, + eps=norm.variance_epsilon, + scale=scale, + bias=None) + + # Copy input into user-buffer memory once (simulates matmul_to_ub in real flow) + ub_input = copy_to_userbuffers(input) + + def func(loop_num=inner_loop): + for _ in range(loop_num): + output = allreduce(ub_input, all_reduce_params=allreduce_params) + return output + + starts = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] + stops = [torch.cuda.Event(enable_timing=True) for _ in range(outer_loop)] + graph = torch.cuda.CUDAGraph() + stream = torch.cuda.Stream() + with torch.cuda.stream(stream): + # warmup + for _ in range(4): + func(loop_num=1) + if enable_cudagraph: + with torch.cuda.graph(graph, stream=stream): + func() + dist.barrier() + delay_kernel(20000, stream) + torch.cuda.synchronize() + for i in range(outer_loop): + starts[i].record(stream) + if enable_cudagraph: + graph.replay() + else: + func() + stops[i].record(stream) + torch.cuda.synchronize() + # Finalize once to sync (simulates userbuffers_allreduce_finalize in real flow) + output = func(loop_num=1) + userbuffers_allreduce_finalize(output[-1]) + runtimes = [starts[i].elapsed_time(stops[i]) for i in range(outer_loop)] + return sorted(runtimes)[len(runtimes) // 2] / inner_loop * 1000.0 + + +def _print_table(fusion_name, strategy_names, rows, world_size): + W_S, W_T, W_H, W_V, W_B = 10, 6, 6, 10, 16 + n = len(strategy_names) + print(flush=True) + print( + f"# Fusion: {fusion_name} world_size={world_size} " + f"algbw = size / time (GB/s)", + flush=True) + print("#", flush=True) + fixed = f"{'size':>{W_S}} {'ntok':>{W_T}} {'hdim':>{W_H}}" + sh = " ".join(f"{s:^{W_V * 2 + 2}}" for s in strategy_names) + print(f"# {fixed} {sh} {'BEST':>{W_B}}", flush=True) + pad = " " * (W_S + 2 + W_T + 2 + W_H) + mh = " ".join(f"{'time(us)':>{W_V}} {'algbw':>{W_V}}" + for _ in strategy_names) + print(f"# {pad} {mh} {' ':>{W_B}}", flush=True) + tw = 2 + W_S + 2 + W_T + 2 + W_H + 2 + n * (W_V * 2 + 2) + (n - + 1) * 2 + 2 + W_B + print("#" + "-" * (tw - 1), flush=True) + for row in rows: + prefix = (f" {row['size_human']:>{W_S}} " + f"{row['num_tokens']:>{W_T}} " + f"{row['hidden_size']:>{W_H}}") + vals, best_name, best_time = [], "N/A", float("inf") + for s in strategy_names: + t, bw = row.get(f"{s}_time"), row.get(f"{s}_algbw") + if t is not None: + vals.append(f"{t:>{W_V}.2f} {bw:>{W_V}.2f}") + if t < best_time: + best_time, best_name = t, s + else: + vals.append(f"{'N/A':>{W_V}} {'N/A':>{W_V}}") + print(f"{prefix} {' '.join(vals)} {best_name:>{W_B}}", flush=True) + + +def allreduce_benchmark_all( + dtype='bfloat16', + test_range="256,268435456,2", + explore_2d=False, + enable_cudagraph=False, + strategy_names=None, + fusion_names=None, + inner_loop=200, + outer_loop=10, + save_csv=None, +): + """Comprehensive benchmark: one table per fusion, all strategies 
side by side.""" + import csv as csv_mod + + world_size = tllm.mpi_world_size() + rank = tllm.mpi_rank() + local_rank = local_mpi_rank() + gpus_per_node = local_mpi_size() + + torch.cuda.set_device(local_rank) + cudart.cudaSetDevice(local_rank) + + mapping = Mapping(world_size, rank, gpus_per_node, tp_size=world_size) + logger.set_rank(mapping.rank) + AutoTuner.get().setup_distributed_state(mapping) + dist = Distributed.get(mapping) + sm_version = get_sm_version() + + if world_size == 1: + raise RuntimeError("Benchmark requires mpi_world_size > 1") + + torch_dtype = tllm._utils.str_dtype_to_torch(dtype) + elem_size = torch.finfo(torch_dtype).bits // 8 + + # Enable MNNVL testing on single-node (bypasses multi-node NVLink check) + os.environ["TLLM_TEST_MNNVL"] = "1" + + # strategies + if strategy_names is None: + strategy_names = [ + "NCCL", "NCCL_SYMMETRIC", "UB", "ONESHOT", "TWOSHOT", "AUTO", + "MNNVL" + ] + strategies = [_STRATEGY_MAP[s] for s in strategy_names] + + # fusions + if fusion_names is None: + fusion_names = list(_FUSION_MAP.keys()) + fusions = [] + for f in fusion_names: + fop = _FUSION_MAP[f] + if fop == AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4 and sm_version < 100: + if rank == 0: + print(f"[WARN] {f} requires SM100+, skipping.", flush=True) + continue + fusions.append((f, fop)) + + # shapes + if explore_2d: + num_tokens_list = [ + 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384 + ] + hidden_size_list = [128, 256, 512, 1024, 2048, 4096, 8192] + shape_list = list(product(num_tokens_list, hidden_size_list)) + else: + min_bytes, max_bytes, ratio = [int(i) for i in test_range.split(",")] + shape_list = [] + nbytes = min_bytes + while nbytes <= max_bytes: + total_elems = nbytes // elem_size + if total_elems <= 4096: + shape_list.append((1, max(total_elems, 1))) + else: + shape_list.append((total_elems // 4096, 4096)) + nbytes *= ratio + + # init user-buffers + need_ub = bool(_UB_STRATEGIES & set(strategies)) + if need_ub: + if ub.ub_supported(): + max_elems = max(s[0] * s[1] for s in shape_list) + ub.initialize_userbuffers_manager(world_size, 1, 1, rank, + torch.cuda.device_count(), + max_elems * elem_size) + else: + if rank == 0: + print("[WARN] ub not supported, skipping UB-based strategies.", + flush=True) + strategies = [s for s in strategies if s not in _UB_STRATEGIES] + strategy_names = [s.name for s in strategies] + + # create AllReduce instances + ar_instances = {} + for strat in strategies: + try: + ar_instances[strat] = AllReduce(mapping=mapping, + strategy=strat, + dtype=torch_dtype) + except Exception as e: + if rank == 0: + print(f"[WARN] Cannot init {strat.name}: {e}", flush=True) + strategies = [s for s in strategies if s in ar_instances] + strategy_names = [s.name for s in strategies] + + max_workspace = CustomAllReduceHelper.max_workspace_size_auto( + mapping.tp_size) + + if rank == 0: + print(f"\n{'=' * 80}", flush=True) + print(" TRT-LLM AllReduce Benchmark", flush=True) + print( + f" world_size={world_size} dtype={dtype} SM={sm_version}" + f" cudagraph={enable_cudagraph}" + f" inner={inner_loop} outer={outer_loop}", + flush=True) + print(f" Strategies : {', '.join(strategy_names)}", flush=True) + print(f" Fusions : {', '.join(f for f, _ in fusions)}", flush=True) + print(f"{'=' * 80}", flush=True) + + csv_rows = [] + + for fusion_name, fusion_op in fusions: + table_rows = [] + for num_tokens, hidden_size in shape_list: + msg_bytes = num_tokens * hidden_size * elem_size + inp = torch.ones((num_tokens, hidden_size), + 
dtype=torch_dtype, + device="cuda") + res = torch.randn_like(inp) + norm = RMSNorm(hidden_size=hidden_size, dtype=torch_dtype, + eps=1e-5).cuda() + norm.weight.data.copy_( + torch.randn((hidden_size, ), dtype=torch_dtype, device="cuda")) + scale = torch.tensor(1.0, dtype=torch.float32).cuda() + + row = dict(size_human=_fmt_size(msg_bytes), + num_tokens=num_tokens, + hidden_size=hidden_size, + size_bytes=msg_bytes) + + for strat in strategies: + sn = strat.name + # skip invalid combos + skip = False + if strat == AllReduceStrategy.TWOSHOT and num_tokens < world_size: + skip = True + elif strat in (AllReduceStrategy.ONESHOT, AllReduceStrategy.TWOSHOT) \ + and msg_bytes > max_workspace: + skip = True + elif strat == AllReduceStrategy.UB and fusion_op == AllReduceFusionOp.NONE: + skip = True + + if skip: + row[f"{sn}_time"] = row[f"{sn}_algbw"] = None + else: + try: + if strat == AllReduceStrategy.UB: + t_us = _profile_ub(mapping, dist, + ar_instances[strat], fusion_op, + inp, res, norm, scale, + enable_cudagraph, inner_loop, + outer_loop) + else: + t_us = profile_allreduce( + mapping=mapping, + dist=dist, + enable_cudagraph=enable_cudagraph, + inner_loop=inner_loop, + outer_loop=outer_loop, + fusion=fusion_op, + input=inp, + residual=res, + norm=norm, + scale=scale, + allreduce_instance=ar_instances[strat]) * 1000.0 + row[f"{sn}_time"] = t_us + row[f"{sn}_algbw"] = msg_bytes / (t_us / 1e6) / 1e9 + except Exception as e: + if rank == 0: + print( + f" [SKIP] {sn} @ {_fmt_size(msg_bytes)}: {e}", + flush=True) + row[f"{sn}_time"] = row[f"{sn}_algbw"] = None + + csv_rows.append({ + "world_size": world_size, + "dtype": dtype, + "fusion": fusion_name, + "num_tokens": num_tokens, + "hidden_size": hidden_size, + "size_bytes": msg_bytes, + "strategy": sn, + "time_us": row[f"{sn}_time"] or 0.0, + "algbw_GBps": row[f"{sn}_algbw"] or 0.0, + }) + table_rows.append(row) + + if rank == 0: + _print_table(fusion_name, strategy_names, table_rows, world_size) + + # summary + if rank == 0: + print(f"\n{'=' * 80}", flush=True) + print(" Summary: peak algbw (GB/s) per strategy per fusion", + flush=True) + print(f"{'=' * 80}", flush=True) + hdr = f" {'fusion':<35s}" + "".join(f" {s:>14s}" + for s in strategy_names) + print(hdr, flush=True) + print(" " + "-" * (len(hdr) - 2), flush=True) + for fn, _ in fusions: + line = f" {fn:<35s}" + for sn in strategy_names: + bws = [ + r["algbw_GBps"] for r in csv_rows if r["fusion"] == fn + and r["strategy"] == sn and r["algbw_GBps"] > 0 + ] + line += f" {max(bws) if bws else 0.0:>14.2f}" + print(line, flush=True) + print(flush=True) + + if rank == 0 and save_csv and csv_rows: + with open(save_csv, "w", newline="") as f: + writer = csv_mod.DictWriter(f, fieldnames=csv_rows[0].keys()) + writer.writeheader() + writer.writerows(csv_rows) + print(f"Results saved to {save_csv}", flush=True) + + if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--dtype", "-t", default="bfloat16") @@ -285,14 +652,28 @@ def allreduce_benchmark( parser.add_argument("--enable_cudagraph", action="store_true") parser.add_argument("--save_csv", type=str, default=None) parser.add_argument("--enable_auto", action="store_true", default=False) + parser.add_argument("--benchmark", + action="store_true", + default=False, + help="Run comprehensive benchmark across all backends " + "with nccl-tests style output") args = parser.parse_args() - allreduce_benchmark( - args.dtype, - args.range, - args.enable_cudagraph, - args.explore_2d, - args.save_csv, - args.enable_auto, - ) + if 
args.benchmark: + allreduce_benchmark_all( + dtype=args.dtype, + test_range=args.range, + explore_2d=args.explore_2d, + enable_cudagraph=args.enable_cudagraph, + save_csv=args.save_csv, + ) + else: + allreduce_benchmark( + args.dtype, + args.range, + args.enable_cudagraph, + args.explore_2d, + args.save_csv, + args.enable_auto, + ) diff --git a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py index b1054a27dd41..79fa32dc74e0 100644 --- a/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py +++ b/tests/unittest/_torch/attention/sparse/test_dsa_indexer.py @@ -716,7 +716,7 @@ def test_indexer_k_cache_scatter_custom_op(): dtype=torch.bfloat16) k_fp8, k_scale = fp8_utils.fp8_quantize_1x128_sf_transpose(k_original) - # Prepare byte-level data + # Prepare byte-level data for the Python reference path scale_size = k_scale.shape[1] * 4 k_fp8_bytes = k_fp8.view(-1).view(torch.uint8).view(num_tokens, head_dim) k_scale_flat = k_scale.view(-1) @@ -755,9 +755,10 @@ def test_indexer_k_cache_scatter_custom_op(): # ========== Path 1: CUDA Kernel ========== print("\n=== Path 1: CUDA Kernel ===") - torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8_bytes, k_scale_bytes, - k_cache_cuda, flat_indices_fp8, - flat_indices_scale) + torch.ops.trtllm.indexer_k_cache_scatter_op(k_fp8, k_scale, k_cache_cuda, + metadata.slot_mapping_fp8, + metadata.slot_mapping_scale, + num_tokens) torch.cuda.synchronize() print("✓ CUDA kernel completed") diff --git a/tests/unittest/_torch/visual_gen/test_visual_gen_params.py b/tests/unittest/_torch/visual_gen/test_visual_gen_params.py new file mode 100644 index 000000000000..212ae3f00c70 --- /dev/null +++ b/tests/unittest/_torch/visual_gen/test_visual_gen_params.py @@ -0,0 +1,844 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Tests for VisualGenParams, ExtraParamSchema, pipeline defaults, default merging, and validation.""" + +from unittest.mock import MagicMock, patch + +import pytest + +# ============================================================================= +# VisualGenParams — Pydantic validation +# ============================================================================= + + +class TestVisualGenParamsValidation: + """VisualGenParams is a Pydantic StrictBaseModel with correct defaults.""" + + def test_default_construction(self): + from tensorrt_llm.visual_gen import VisualGenParams + + params = VisualGenParams() + # Universal fields default to None + assert params.height is None + assert params.width is None + assert params.num_inference_steps is None + assert params.guidance_scale is None + assert params.max_sequence_length is None + assert params.num_frames is None + assert params.frame_rate is None + assert params.negative_prompt is None + assert params.image is None + assert params.mask is None + assert params.image_cond_strength is None + # Concrete defaults + assert params.seed == 42 + assert params.num_images_per_prompt == 1 + # Extra params + assert params.extra_params is None + + def test_explicit_values(self): + from tensorrt_llm.visual_gen import VisualGenParams + + params = VisualGenParams( + height=720, + width=1280, + num_inference_steps=50, + guidance_scale=5.0, + seed=123, + ) + assert params.height == 720 + assert params.width == 1280 + assert params.num_inference_steps == 50 + assert params.guidance_scale == 5.0 + assert params.seed == 123 + + def test_unknown_field_rejected(self): + from pydantic import ValidationError + + from tensorrt_llm.visual_gen import VisualGenParams + + with pytest.raises(ValidationError): + VisualGenParams(stg_scale=0.5) + + def test_extra_params_accepted(self): + from tensorrt_llm.visual_gen import VisualGenParams + + params = VisualGenParams(extra_params={"stg_scale": 0.5, "enhance_prompt": True}) + assert params.extra_params["stg_scale"] == 0.5 + assert params.extra_params["enhance_prompt"] is True + + def test_image_accepts_str(self): + from tensorrt_llm.visual_gen import VisualGenParams + + params = VisualGenParams(image="/path/to/image.png") + assert params.image == "/path/to/image.png" + + def test_image_accepts_bytes(self): + from tensorrt_llm.visual_gen import VisualGenParams + + params = VisualGenParams(image=b"\x89PNG") + assert params.image == b"\x89PNG" + + def test_image_accepts_list(self): + from tensorrt_llm.visual_gen import VisualGenParams + + params = VisualGenParams(image=["/path/a.png", b"\x89PNG"]) + assert len(params.image) == 2 + + def test_model_dump(self): + from tensorrt_llm.visual_gen import VisualGenParams + + params = VisualGenParams(height=512, seed=1) + d = params.model_dump() + assert d["height"] == 512 + assert d["seed"] == 1 + assert d["width"] is None + + def test_negative_prompt_on_params(self): + from tensorrt_llm.visual_gen import VisualGenParams + + params = VisualGenParams(negative_prompt="blurry, low quality") + assert params.negative_prompt == "blurry, low quality" + + +# ============================================================================= +# ExtraParamSchema +# ============================================================================= + + +class TestExtraParamSchema: + """ExtraParamSchema construction and field access.""" + + def test_basic_construction(self): + from tensorrt_llm._torch.visual_gen.pipeline import ExtraParamSchema + + spec = ExtraParamSchema(type="float", default=0.0, 
description="test param") + assert spec.type == "float" + assert spec.default == 0.0 + assert spec.description == "test param" + assert spec.range is None + + def test_with_range(self): + from tensorrt_llm._torch.visual_gen.pipeline import ExtraParamSchema + + spec = ExtraParamSchema(type="float", range=(0.0, 1.0)) + assert spec.range == (0.0, 1.0) + + def test_none_default(self): + from tensorrt_llm._torch.visual_gen.pipeline import ExtraParamSchema + + spec = ExtraParamSchema(type="str") + assert spec.default is None + + def test_public_import(self): + from tensorrt_llm import ExtraParamSchema + + spec = ExtraParamSchema(type="int", default=42) + assert spec.default == 42 + + +# ============================================================================= +# Pipeline DEFAULT_GENERATION_PARAMS and EXTRA_PARAM_SPECS +# ============================================================================= + + +class TestPipelineDefaults: + """Each pipeline declares correct default generation params.""" + + def test_wan_defaults(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + d = WanPipeline.DEFAULT_GENERATION_PARAMS + assert d["height"] == 480 + assert d["width"] == 832 + assert d["num_inference_steps"] == 50 + assert d["guidance_scale"] == 5.0 + assert d["num_frames"] == 81 + + def test_flux_defaults(self): + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + + d = FluxPipeline.DEFAULT_GENERATION_PARAMS + assert d["height"] == 1024 + assert d["width"] == 1024 + assert d["guidance_scale"] == 3.5 + + def test_ltx2_defaults(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + d = LTX2Pipeline.DEFAULT_GENERATION_PARAMS + assert d["height"] == 512 + assert d["width"] == 768 + assert d["num_inference_steps"] == 40 + assert d["guidance_scale"] == 4.0 + assert d["max_sequence_length"] == 1024 + assert d["num_frames"] == 121 + + def test_base_pipeline_empty_defaults(self): + from tensorrt_llm._torch.visual_gen.pipeline import BasePipeline + + assert BasePipeline.DEFAULT_GENERATION_PARAMS == {} + assert BasePipeline.EXTRA_PARAM_SPECS == {} + + +class TestPipelineExtraParamSpecs: + """Each pipeline declares correct extra param specs.""" + + def test_wan_extra_specs(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + specs = WanPipeline.EXTRA_PARAM_SPECS + assert "guidance_scale_2" in specs + assert "boundary_ratio" in specs + assert specs["guidance_scale_2"].type == "float" + assert specs["boundary_ratio"].range == (0.0, 1.0) + + def test_wan_i2v_extra_specs(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import ( + WanImageToVideoPipeline, + ) + + specs = WanImageToVideoPipeline.EXTRA_PARAM_SPECS + assert "last_image" in specs + assert "guidance_scale_2" in specs + assert specs["last_image"].type == "str" + + def test_flux_no_extra_specs(self): + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + + assert FluxPipeline.EXTRA_PARAM_SPECS == {} + + def test_ltx2_extra_specs(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + specs = LTX2Pipeline.EXTRA_PARAM_SPECS + expected_keys = { + "output_type", + "guidance_rescale", + "stg_scale", + "stg_blocks", + "modality_scale", + "rescale_scale", + "guidance_skip_step", + "enhance_prompt", + } + assert set(specs.keys()) == expected_keys + assert specs["stg_scale"].default == 0.0 + assert 
specs["enhance_prompt"].default is False + assert specs["stg_blocks"].default is None + + def test_ltx2_extra_specs_attribute_access(self): + """Direct attribute-style access works: Pipeline.EXTRA_PARAM_SPECS['key'].""" + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + assert LTX2Pipeline.EXTRA_PARAM_SPECS["modality_scale"].type == "float" + assert LTX2Pipeline.EXTRA_PARAM_SPECS["modality_scale"].default == 1.0 + + +# ============================================================================= +# Executor default merging +# ============================================================================= + + +class TestDefaultMerging: + """DiffusionExecutor._merge_defaults fills None fields correctly.""" + + def _make_mock_executor(self, pipeline_cls): + """Create a mock DiffusionExecutor with the given pipeline class's specs.""" + executor = MagicMock() + executor.pipeline = MagicMock() + executor.pipeline.DEFAULT_GENERATION_PARAMS = pipeline_cls.DEFAULT_GENERATION_PARAMS + executor.pipeline.EXTRA_PARAM_SPECS = pipeline_cls.EXTRA_PARAM_SPECS + return executor + + def _make_request(self, **kwargs): + from tensorrt_llm._torch.visual_gen.executor import DiffusionRequest + + return DiffusionRequest(request_id=0, prompt=["test"], **kwargs) + + def _merge(self, executor, req): + from tensorrt_llm._torch.visual_gen.executor import DiffusionExecutor + + DiffusionExecutor._merge_defaults(executor, req) + + def test_universal_defaults_merged(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + executor = self._make_mock_executor(WanPipeline) + req = self._make_request() + assert req.height is None + + self._merge(executor, req) + assert req.height == 480 + assert req.width == 832 + assert req.num_inference_steps == 50 + + def test_user_values_not_overwritten(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + executor = self._make_mock_executor(WanPipeline) + req = self._make_request(height=1080, width=1920) + + self._merge(executor, req) + assert req.height == 1080 # User value preserved + assert req.width == 1920 + assert req.num_inference_steps == 50 # Default filled + + def test_extra_params_defaults_merged(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + executor = self._make_mock_executor(LTX2Pipeline) + req = self._make_request() + + self._merge(executor, req) + assert req.extra_params is not None + assert req.extra_params["stg_scale"] == 0.0 + assert req.extra_params["output_type"] == "pt" + assert req.extra_params["enhance_prompt"] is False + # None defaults are also filled + assert req.extra_params["stg_blocks"] is None + + def test_user_extra_params_not_overwritten(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + executor = self._make_mock_executor(LTX2Pipeline) + req = self._make_request(extra_params={"stg_scale": 0.5}) + + self._merge(executor, req) + assert req.extra_params["stg_scale"] == 0.5 # User value preserved + assert req.extra_params["output_type"] == "pt" # Default filled + + def test_no_extra_params_for_flux(self): + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + + executor = self._make_mock_executor(FluxPipeline) + req = self._make_request() + + self._merge(executor, req) + assert req.extra_params is None # Flux has no extra specs + + def test_all_declared_keys_present_after_merge(self): + """After merge, all EXTRA_PARAM_SPECS keys are 
in extra_params.""" + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + executor = self._make_mock_executor(LTX2Pipeline) + req = self._make_request(extra_params={"stg_scale": 0.5}) + + self._merge(executor, req) + for key in LTX2Pipeline.EXTRA_PARAM_SPECS: + assert key in req.extra_params, f"Missing key: {key}" + + +# ============================================================================= +# VisualGen.default_params and extra_param_specs +# ============================================================================= + + +class TestVisualGenDefaultParams: + """VisualGen.default_params returns correctly merged params per pipeline. + + VisualGen delegates to executor.default_generation_params and + executor.extra_param_specs (populated from the READY signal). + """ + + def _make_visual_gen(self, pipeline_cls): + """Create VisualGen with mocked init and executor carrying pipeline metadata.""" + from tensorrt_llm.visual_gen import VisualGen + + with patch.object(VisualGen, "__init__", lambda self, *a, **kw: None): + vg = VisualGen.__new__(VisualGen) + vg.executor = MagicMock() + if pipeline_cls is not None: + vg.executor.default_generation_params = pipeline_cls.DEFAULT_GENERATION_PARAMS + vg.executor.extra_param_specs = pipeline_cls.EXTRA_PARAM_SPECS + else: + vg.executor.default_generation_params = {} + vg.executor.extra_param_specs = {} + return vg + + def test_ltx2_default_params(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + vg = self._make_visual_gen(LTX2Pipeline) + params = vg.default_params + assert params.height == 512 + assert params.width == 768 + assert params.num_inference_steps == 40 + assert params.seed == 42 + assert params.extra_params is not None + assert params.extra_params["stg_scale"] == 0.0 + assert params.extra_params["output_type"] == "pt" + # None-default keys are present + assert "stg_blocks" in params.extra_params + + def test_wan_default_params(self): + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + vg = self._make_visual_gen(WanPipeline) + params = vg.default_params + assert params.height == 480 + assert params.width == 832 + assert params.extra_params is not None + assert "guidance_scale_2" in params.extra_params + assert "boundary_ratio" in params.extra_params + + def test_flux_default_params_no_extra(self): + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + + vg = self._make_visual_gen(FluxPipeline) + params = vg.default_params + assert params.height == 1024 + assert params.width == 1024 + assert params.extra_params is None + + def test_no_pipeline_returns_bare_params(self): + vg = self._make_visual_gen(None) + params = vg.default_params + assert params.height is None + assert params.extra_params is None + + def test_extra_param_specs(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + vg = self._make_visual_gen(LTX2Pipeline) + specs = vg.extra_param_specs + assert "stg_scale" in specs + assert specs["stg_scale"].type == "float" + + def test_extra_param_specs_empty_for_flux(self): + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + + vg = self._make_visual_gen(FluxPipeline) + assert vg.extra_param_specs == {} + + +# ============================================================================= +# Pipeline metadata bridging (executor → client) +# ============================================================================= + + 
+class TestPipelineMetadataBridging: + """Verify DEFAULT_GENERATION_PARAMS and EXTRA_PARAM_SPECS survive + the pickle round-trip from DiffusionExecutor READY signal to the client.""" + + def _build_ready_response(self, pipeline_cls): + """Build a DiffusionResponse matching what DiffusionExecutor sends.""" + from tensorrt_llm._torch.visual_gen.executor import DiffusionResponse + + return DiffusionResponse( + request_id=-1, + output={ + "status": "READY", + "default_generation_params": pipeline_cls.DEFAULT_GENERATION_PARAMS, + "extra_param_specs": pipeline_cls.EXTRA_PARAM_SPECS, + }, + ) + + def _roundtrip(self, response): + """Pickle/unpickle to simulate ZMQ transport.""" + import pickle + + return pickle.loads(pickle.dumps(response)) + + def test_ready_payload_pickle_roundtrip(self): + """The READY dict survives pickle (the ZMQ transport layer).""" + from tensorrt_llm._torch.visual_gen.executor import DiffusionResponse + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + resp = self._build_ready_response(LTX2Pipeline) + restored = self._roundtrip(resp) + + assert isinstance(restored, DiffusionResponse) + assert restored.request_id == -1 + payload = restored.output + assert isinstance(payload, dict) + assert payload["status"] == "READY" + assert payload["default_generation_params"] == LTX2Pipeline.DEFAULT_GENERATION_PARAMS + assert set(payload["extra_param_specs"].keys()) == set( + LTX2Pipeline.EXTRA_PARAM_SPECS.keys() + ) + + def test_extra_param_schema_type_preserved(self): + """ExtraParamSchema instances keep their type through pickle.""" + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + from tensorrt_llm._torch.visual_gen.pipeline import ExtraParamSchema + + resp = self._build_ready_response(LTX2Pipeline) + restored = self._roundtrip(resp) + + specs = restored.output["extra_param_specs"] + for key, spec in specs.items(): + assert isinstance(spec, ExtraParamSchema), ( + f"spec '{key}' lost its type: got {type(spec).__name__}" + ) + + def test_extra_param_schema_fields_preserved(self): + """ExtraParamSchema field values survive the round-trip.""" + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + resp = self._build_ready_response(LTX2Pipeline) + restored = self._roundtrip(resp) + + specs = restored.output["extra_param_specs"] + original = LTX2Pipeline.EXTRA_PARAM_SPECS + for key in original: + assert specs[key].type == original[key].type + assert specs[key].default == original[key].default + assert specs[key].range == original[key].range + assert specs[key].description == original[key].description + + def test_wan_pipeline_roundtrip(self): + """Wan pipeline metadata survives the round-trip.""" + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + resp = self._build_ready_response(WanPipeline) + restored = self._roundtrip(resp) + + payload = restored.output + assert payload["default_generation_params"]["height"] == 480 + assert payload["default_generation_params"]["num_frames"] == 81 + assert "guidance_scale_2" in payload["extra_param_specs"] + assert "boundary_ratio" in payload["extra_param_specs"] + + def test_flux_empty_specs_roundtrip(self): + """Pipeline with no EXTRA_PARAM_SPECS round-trips as empty dict.""" + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + + resp = self._build_ready_response(FluxPipeline) + restored = self._roundtrip(resp) + + assert restored.output["extra_param_specs"] == {} + assert 
restored.output["default_generation_params"]["height"] == 1024 + + def test_client_extracts_metadata_from_ready(self): + """DiffusionRemoteClient stores metadata when processing a READY response.""" + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + from tensorrt_llm._torch.visual_gen.pipeline import ExtraParamSchema + from tensorrt_llm.visual_gen.visual_gen import DiffusionRemoteClient + + resp = self._build_ready_response(LTX2Pipeline) + restored = self._roundtrip(resp) + + # Simulate what _wait_ready_async does: extract from the response payload + client = MagicMock(spec=DiffusionRemoteClient) + client.default_generation_params = {} + client.extra_param_specs = {} + + payload = restored.output + if isinstance(payload, dict): + client.default_generation_params = payload.get("default_generation_params", {}) + client.extra_param_specs = payload.get("extra_param_specs", {}) + + assert client.default_generation_params == LTX2Pipeline.DEFAULT_GENERATION_PARAMS + assert set(client.extra_param_specs.keys()) == set(LTX2Pipeline.EXTRA_PARAM_SPECS.keys()) + for spec in client.extra_param_specs.values(): + assert isinstance(spec, ExtraParamSchema) + + +# ============================================================================= +# VisualGenParamsError — error class +# ============================================================================= + + +class TestVisualGenParamsError: + """VisualGenParamsError is importable and is a subclass of ValueError.""" + + def test_import_from_top_level(self): + from tensorrt_llm import VisualGenParamsError + + assert issubclass(VisualGenParamsError, ValueError) + + def test_import_from_visual_gen(self): + from tensorrt_llm.visual_gen import VisualGenParamsError + + assert VisualGenParamsError is not None + + def test_is_subclass_of_value_error(self): + from tensorrt_llm.visual_gen import VisualGenParamsError + + assert issubclass(VisualGenParamsError, ValueError) + assert not issubclass(VisualGenParamsError, RuntimeError) + + def test_raise_and_catch_as_value_error(self): + from tensorrt_llm.visual_gen import VisualGenParamsError + + with pytest.raises(ValueError): + raise VisualGenParamsError("bad param") + + def test_message_preserved(self): + from tensorrt_llm.visual_gen import VisualGenParamsError + + with pytest.raises(VisualGenParamsError, match="height.*out of range"): + raise VisualGenParamsError("height is out of range") + + +# ============================================================================= +# Request validation — _validate_request +# ============================================================================= + + +class TestRequestValidation: + """DiffusionExecutor._validate_request raises VisualGenParamsError on bad params.""" + + def _make_mock_executor(self, pipeline_cls): + executor = MagicMock() + executor.pipeline = MagicMock() + executor.pipeline.__class__ = pipeline_cls + executor.pipeline.DEFAULT_GENERATION_PARAMS = pipeline_cls.DEFAULT_GENERATION_PARAMS + executor.pipeline.EXTRA_PARAM_SPECS = pipeline_cls.EXTRA_PARAM_SPECS + return executor + + def _make_request(self, **kwargs): + from tensorrt_llm._torch.visual_gen.executor import DiffusionRequest + + return DiffusionRequest(request_id=0, prompt=["test"], **kwargs) + + def _validate(self, executor, req): + from tensorrt_llm._torch.visual_gen.executor import DiffusionExecutor + + DiffusionExecutor._validate_request(executor, req) + + def _merge_and_validate(self, executor, req): + from tensorrt_llm._torch.visual_gen.executor import 
DiffusionExecutor + + DiffusionExecutor._merge_defaults(executor, req) + DiffusionExecutor._validate_request(executor, req) + + # --- unknown extra_params --- + + def test_unknown_extra_params_raises(self): + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + from tensorrt_llm.visual_gen import VisualGenParamsError + + executor = self._make_mock_executor(FluxPipeline) + req = self._make_request(extra_params={"nonexistent_key": 42}) + with pytest.raises(VisualGenParamsError, match="Unknown extra_params"): + self._validate(executor, req) + + def test_unknown_extra_params_lists_supported_keys(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + from tensorrt_llm.visual_gen import VisualGenParamsError + + executor = self._make_mock_executor(LTX2Pipeline) + req = self._make_request(extra_params={"bad_key": 1}) + with pytest.raises(VisualGenParamsError, match="Supported"): + self._validate(executor, req) + + def test_valid_extra_params_accepted(self): + from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline + + executor = self._make_mock_executor(LTX2Pipeline) + req = self._make_request(extra_params={"stg_scale": 0.5}) + self._merge_and_validate(executor, req) # should not raise + + # --- unsupported universal fields --- + + def test_num_frames_on_image_pipeline_raises(self): + """num_frames=81 to FLUX (image-only) should raise.""" + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + from tensorrt_llm.visual_gen import VisualGenParamsError + + executor = self._make_mock_executor(FluxPipeline) + req = self._make_request(num_frames=81) + with pytest.raises(VisualGenParamsError, match="num_frames.*not use it"): + self._validate(executor, req) + + def test_frame_rate_on_image_pipeline_raises(self): + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + from tensorrt_llm.visual_gen import VisualGenParamsError + + executor = self._make_mock_executor(FluxPipeline) + req = self._make_request(frame_rate=24.0) + with pytest.raises(VisualGenParamsError, match="frame_rate.*not use it"): + self._validate(executor, req) + + def test_image_not_checked_by_validator(self): + """image is a conditioning input — validated at runtime by infer(), not here.""" + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + executor = self._make_mock_executor(WanPipeline) + req = self._make_request(image="/path/to/img.png") + # Should not raise — image validation is the pipeline's responsibility + self._merge_and_validate(executor, req) + + def test_num_frames_on_video_pipeline_ok(self): + """num_frames is declared by WanPipeline, should not raise.""" + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + executor = self._make_mock_executor(WanPipeline) + req = self._make_request(num_frames=81) + self._merge_and_validate(executor, req) + + def test_image_on_i2v_pipeline_ok(self): + """image is declared by WanImageToVideoPipeline, should not raise.""" + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import ( + WanImageToVideoPipeline, + ) + + executor = self._make_mock_executor(WanImageToVideoPipeline) + req = self._make_request(image="/path/to/img.png") + self._merge_and_validate(executor, req) + + def test_none_fields_not_flagged(self): + """Fields left as None should never trigger unsupported-field errors.""" + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import 
FluxPipeline
+
+        executor = self._make_mock_executor(FluxPipeline)
+        req = self._make_request()  # all None
+        self._merge_and_validate(executor, req)
+
+    # --- type validation on extra_params ---
+
+    def test_wrong_type_extra_param_raises(self):
+        from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline
+        from tensorrt_llm.visual_gen import VisualGenParamsError
+
+        executor = self._make_mock_executor(LTX2Pipeline)
+        req = self._make_request(extra_params={"stg_scale": "not_a_number"})
+        with pytest.raises(VisualGenParamsError, match="expected type 'float'"):
+            self._merge_and_validate(executor, req)
+
+    def test_int_accepted_for_float_spec(self):
+        """An int value should be accepted for a 'float'-typed spec."""
+        from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline
+
+        executor = self._make_mock_executor(LTX2Pipeline)
+        req = self._make_request(extra_params={"stg_scale": 1})
+        self._merge_and_validate(executor, req)
+
+    def test_bool_passes_float_spec_type_check(self):
+        """A bool is accepted for a 'float' spec: bool is a subclass of int
+        in Python, and int values satisfy float-typed specs."""
+        from tensorrt_llm._torch.visual_gen.models.ltx2.pipeline_ltx2 import LTX2Pipeline
+
+        executor = self._make_mock_executor(LTX2Pipeline)
+        req = self._make_request(extra_params={"stg_scale": True})
+        # bool is an instance of int, which the type check accepts for floats.
+        self._merge_and_validate(executor, req)
+
+    def test_wrong_type_str_extra_param(self):
+        from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan_i2v import (
+            WanImageToVideoPipeline,
+        )
+        from tensorrt_llm.visual_gen import VisualGenParamsError
+
+        executor = self._make_mock_executor(WanImageToVideoPipeline)
+        req = self._make_request(
+            image="/img.png",
+            extra_params={"last_image": 123},
+        )
+        with pytest.raises(VisualGenParamsError, match="expected type 'str'"):
+            self._merge_and_validate(executor, req)
+
+    # --- range validation on extra_params ---
+
+    def test_out_of_range_extra_param_raises(self):
+        from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline
+        from tensorrt_llm.visual_gen import VisualGenParamsError
+
+        executor = self._make_mock_executor(WanPipeline)
+        # boundary_ratio has range (0.0, 1.0)
+        req = self._make_request(extra_params={"boundary_ratio": 2.0})
+        with pytest.raises(VisualGenParamsError, match="out of range"):
+            self._merge_and_validate(executor, req)
+
+    def test_negative_boundary_ratio_raises(self):
+        from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline
+        from tensorrt_llm.visual_gen import VisualGenParamsError
+
+        executor = self._make_mock_executor(WanPipeline)
+        req = self._make_request(extra_params={"boundary_ratio": -0.5})
+        with pytest.raises(VisualGenParamsError, match="out of range"):
+            self._merge_and_validate(executor, req)
+
+    def test_boundary_value_at_range_edge_ok(self):
+        from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline
+
+        executor = self._make_mock_executor(WanPipeline)
+        req = self._make_request(extra_params={"boundary_ratio": 0.0})
+        self._merge_and_validate(executor, req)
+
+    def test_boundary_value_at_range_max_ok(self):
+        from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline
+
+        executor = self._make_mock_executor(WanPipeline)
+        req = self._make_request(extra_params={"boundary_ratio": 1.0})
+        self._merge_and_validate(executor, req)
+
+    # --- 
multiple errors collected --- + + def test_multiple_errors_in_single_message(self): + """Multiple validation failures should be collected into one error.""" + from tensorrt_llm._torch.visual_gen.models.flux.pipeline_flux import FluxPipeline + from tensorrt_llm.visual_gen import VisualGenParamsError + + executor = self._make_mock_executor(FluxPipeline) + req = self._make_request( + num_frames=81, + frame_rate=24.0, + extra_params={"bogus": 1}, + ) + with pytest.raises(VisualGenParamsError) as exc_info: + self._validate(executor, req) + msg = str(exc_info.value) + assert "num_frames" in msg + assert "frame_rate" in msg + assert "bogus" in msg + + # --- None extra_params values skip type/range check --- + + def test_none_extra_param_value_skipped(self): + """None values for extra_params with range specs should not fail validation.""" + from tensorrt_llm._torch.visual_gen.models.wan.pipeline_wan import WanPipeline + + executor = self._make_mock_executor(WanPipeline) + req = self._make_request(extra_params={"boundary_ratio": None}) + self._merge_and_validate(executor, req) + + # --- process_request returns error response instead of crashing --- + + def test_process_request_returns_error_on_validation_failure(self): + """Validation errors become error responses, not server crashes.""" + from tensorrt_llm._torch.visual_gen.executor import DiffusionExecutor, DiffusionResponse + + # Build a mock with real method bindings for the three methods + # that process_request chains through. + executor = MagicMock() + executor.pipeline = MagicMock() + executor.pipeline.__class__.__name__ = "FluxPipeline" + executor.pipeline.DEFAULT_GENERATION_PARAMS = {"height": 1024, "width": 1024} + executor.pipeline.EXTRA_PARAM_SPECS = {} + executor.pipeline._warmed_up_shapes = set() + executor.pipeline.warmup_cache_key = MagicMock(return_value=(1024, 1024, None)) + executor.rank = 0 + executor.device_id = 0 + executor.response_queue = MagicMock() + + # Wire real methods onto the mock so process_request uses them + executor._merge_defaults = lambda req: DiffusionExecutor._merge_defaults(executor, req) + executor._validate_request = lambda req: DiffusionExecutor._validate_request(executor, req) + + req = self._make_request(num_frames=81, extra_params={"bad": 1}) + + # Call the real process_request + DiffusionExecutor.process_request(executor, req) + + # Should have put an error response, not crashed + executor.response_queue.put.assert_called_once() + resp = executor.response_queue.put.call_args[0][0] + assert isinstance(resp, DiffusionResponse) + assert resp.error_msg is not None + assert "validation failed" in resp.error_msg.lower() diff --git a/tests/unittest/disaggregated/test_router.py b/tests/unittest/disaggregated/test_router.py index 04f08737392b..22674ea309ff 100644 --- a/tests/unittest/disaggregated/test_router.py +++ b/tests/unittest/disaggregated/test_router.py @@ -8,8 +8,9 @@ from tensorrt_llm.serve.openai_protocol import (ChatCompletionRequest, CompletionRequest, DisaggregatedParams) -from tensorrt_llm.serve.router import (KvCacheAwareRouter, LoadBalancingRouter, - RoundRobinRouter, create_router) +from tensorrt_llm.serve.router import (ConversationRouter, KvCacheAwareRouter, + LoadBalancingRouter, RoundRobinRouter, + create_router) # Mock class for metadata server @@ -155,26 +156,22 @@ async def test_request_balancing_router(servers, requests_fixture, request): use_tokens=False) requests = request.getfixturevalue(requests_fixture) - server, _ = await router.get_next_server(requests[0]) - assert server == 
"server1" - server, _ = await router.get_next_server(requests[1]) - assert server == "server2" - server, _ = await router.get_next_server(requests[2]) - assert server == "server3" + # First 3 requests: all servers start at 0 load, each gets a unique server + assigned = {} + for i in range(3): + server, _ = await router.get_next_server(requests[i]) + assigned[i] = server + assert len(set(assigned.values())) == 3, "All 3 servers should be used" - # Similulate terminating 3rd request (on server 3) + # Finish 3rd request — its server drops to 0 (uniquely least loaded) await router.finish_request(requests[2]) - - # Now server3 is least loaded server, _ = await router.get_next_server(requests[3]) - assert server == "server3" + assert server == assigned[2] - # Simulate terminating 4th request (on server 3) + # Finish 2nd request — its server drops to 0 (uniquely least loaded) await router.finish_request(requests[1]) - - # Now server2 is least loaded server, _ = await router.get_next_server(requests[4]) - assert server == "server2" + assert server == assigned[1] @pytest.mark.asyncio @@ -185,52 +182,33 @@ async def test_tokens_balancing_router(servers, requests_fixture, request): use_tokens=True) requests = request.getfixturevalue(requests_fixture) + # prompt_lengths = [100, 500, 10, 400, 2000, 100] server_sequence = [(await router.get_next_server(req))[0] for req in requests] - # Loads at each step: - # Step 0: - # server1: 100 - # server2: 0 - # server3: 0 - - # Step 1: - # server1: 100 - # server2: 500 - # server3: 0 - - # Step 2: - # server1: 100 - # server2: 500 - # server3: 10 - - # Step 3: - # server1: 100 - # server2: 500 - # server3: 410 - - # Step 4: - # server1: 2100 - # server2: 500 - # server3: 410 - - # Step 5: - # server1: 2100 - # server2: 500 - # server3: 510 - assert server_sequence == [ - "server1", "server2", "server3", "server3", "server1", "server3" - ] + # Steps 0-1: tied loads → implementation-defined assignment. + # Step 2+: unique least-loaded → deterministic relative to steps 0-1. 
+ s0, s1, s2 = server_sequence[0], server_sequence[1], server_sequence[2] + assert len({s0, s1, s2}) == 3, "All 3 servers should be used" + + # After step 2: s0=100, s1=500, s2=10 + # Step 3: s2 uniquely least (10 < 100 < 500) + assert server_sequence[3] == s2 - # Simulate terminating 5th request (on server 1) + # After step 3: s0=100, s1=500, s2=410 + # Step 4: s0 uniquely least (100 < 410 < 500) + assert server_sequence[4] == s0 + + # After step 4: s0=2100, s1=500, s2=410 + # Step 5: s2 uniquely least (410 < 500 < 2100) + assert server_sequence[5] == s2 + + # Finish 5th request (2000 tokens on s0) await router.finish_request(requests[4]) server, _ = await router.get_next_server(requests[4]) - # New loads: - #server1: 100 - #server2: 500 - #server3: 510 - assert server == "server1" + # After finish: s0=100, s1=500, s2=510 → s0 uniquely least + assert server == s0 @pytest.mark.asyncio @@ -251,6 +229,7 @@ async def test_gen_tokens_balancing_router(servers, requests_fixture, request): @pytest.mark.asyncio async def test_kv_cache_aware_router(servers): # create tokenized requests to skip tokenization + # req0: [1000]*100, req1: [1000]*50+[1001]*150, req2: [1002]*300 requests = [ CompletionRequest(model="TinyLlama", prompt=[[1000] * 100]), CompletionRequest(model="TinyLlama", @@ -264,8 +243,17 @@ async def test_kv_cache_aware_router(servers): max_batch_size=32, tokens_per_block=32) results = [await router.get_next_server(req) for req in requests] - servers, infos = zip(*results) - assert servers == ("server1", "server2", "server3") + assigned_servers, infos = zip(*results) + # Initial routing (empty caches): all 3 should get distinct servers + assert len(set(assigned_servers)) == 3 + + # Track which server cached which request + server_of = {i: assigned_servers[i] for i in range(3)} + all_servers = list(router._server_state.keys()) + + def matches_by_server(info): + """Map server → matched token count from positional matches list.""" + return dict(zip(all_servers, info["matches"])) # manually updates since no real server is involved for request in requests: @@ -279,44 +267,65 @@ async def test_kv_cache_aware_router(servers): assert infos[0]["block_hashes"][0][0] == infos[1]["block_hashes"][0][0] # no workloads, route by kv cache hits + # reversed: [req2, req1, req0] — each should route to its cached server results = [await router.get_next_server(req) for req in reversed(requests)] - servers, infos = zip(*results) - assert servers == ("server3", "server2", "server1") + hit_servers, hit_infos = zip(*results) + assert hit_servers == (server_of[2], server_of[1], server_of[0]) + # matched partial block will be counted as a whole block - assert infos[0]["matches"] == [0, 0, 320] - assert infos[1]["matches"] == [32, 224, 0] - assert infos[2]["matches"] == [128, 32, 0] + # req2 ([1002]*300): only matches server_of[2] → 320 tokens + m0 = matches_by_server(hit_infos[0]) + assert m0[server_of[2]] == 320 + assert m0[server_of[0]] == 0 + assert m0[server_of[1]] == 0 + # req1 ([1000]*50+[1001]*150): full match server_of[1] → 224, partial server_of[0] → 32 + m1 = matches_by_server(hit_infos[1]) + assert m1[server_of[1]] == 224 + assert m1[server_of[0]] == 32 + assert m1[server_of[2]] == 0 + # req0 ([1000]*100): full match server_of[0] → 128, partial server_of[1] → 32 + m2 = matches_by_server(hit_infos[2]) + assert m2[server_of[0]] == 128 + assert m2[server_of[1]] == 32 + assert m2[server_of[2]] == 0 for request in requests: await router.finish_request(request) - # block-wise (32/block) hit rate: 96/512, 
32/512, 0/512
+    # block-wise (32/block) hit rate: server_of[0]=96/512, server_of[1]=32/512, server_of[2]=0/512
     another_request = CompletionRequest(model="TinyLlama",
                                         prompt=[[1000] * 500])
     dup_requests = [copy.copy(another_request) for _ in range(20)]
     another_results = [
         await router.get_next_server(req) for req in dup_requests
     ]
-    servers, infos = zip(*another_results)
+    dup_servers, dup_infos = zip(*another_results)
     # due to workload balancing, not all requests are sent to the same server
-    # distribution is related to the hit rate
-    counts = {server: 0 for server in servers}
-    for server in servers:
-        counts[server] += 1
-    assert counts["server1"] > counts["server2"] > counts["server3"] > 0
-    assert infos[0]["matches"] == [96, 32, 0]
+    # distribution follows cache hit rate
+    counts = {s: 0 for s in dup_servers}
+    for s in dup_servers:
+        counts[s] += 1
+    assert counts[server_of[0]] > counts[server_of[1]] > counts[
+        server_of[2]] > 0
+    dup_m = matches_by_server(dup_infos[0])
+    assert dup_m[server_of[0]] == 96
+    assert dup_m[server_of[1]] == 32
+    assert dup_m[server_of[2]] == 0
     for req in dup_requests:
         await router.finish_request(req)
-    # test router after block eviction on server 1&2
-    # results: server3(request2), server2(request1), server1(request0)
-    for server, infos in results[1:]:
-        assert server in ["server1", "server2"]
-        events = [{"type": "removed", "block_hashes": infos["block_hashes"][0]}]
+    # test router after block eviction on servers that cached req0 and req1
+    # results[0] = (server_of[2], ...), results[1:] are server_of[1] and server_of[0]
+    for server, info in results[1:]:
+        assert server in [server_of[0], server_of[1]]
+        events = [{"type": "removed", "block_hashes": info["block_hashes"][0]}]
         router._server_state[server].update_with_events(events)
+    # Only server_of[2] still has cached blocks (req2)
     results = [await router.get_next_server(req) for req in reversed(requests)]
-    servers, infos = zip(*results)
-    assert servers == ("server3", "server1", "server2")
+    final_servers, _ = zip(*results)
+    # req2 routes to server_of[2] (full cache hit); others spread elsewhere
+    assert final_servers[0] == server_of[2]
+    assert len(set(final_servers)) == 3
@@ -324,6 +333,7 @@ async def test_kv_cache_aware_router(servers):
 async def test_kv_cache_aware_router_multi_turn_conversation(api_type):
     """Test that consecutive turns of a multi-turn conversation route to
     the same server due to KV cache prefix hits.
+    Each turn replays and extends the previous prompt, so its blocks prefix-match the server that served the earlier turns.
Simulates two concurrent sessions inspired by agentic_data/dataset_sample2000.jsonl session sess-fca58a1f44cd: Turn 0: 68 hash_ids (system prompt + first user input) @@ -621,3 +631,165 @@ async def test_get_next_server_exclude_server_insufficient(router_class): await router.get_next_server(CompletionRequest(model="TinyLlama", prompt=[[10] * 10]), exclude_server=servers[0]) + + +# ── ConversationRouter tests ── + + +def _make_request(conversation_id=None, prompt="the " * 100): + params = DisaggregatedParams(request_type="context_only", + conversation_id=conversation_id) + return CompletionRequest(model="TinyLlama", + prompt=[prompt], + disaggregated_params=params) + + +@pytest.mark.asyncio +async def test_conversation_router_session_affinity_and_fallbacks(): + """Session affinity, exclude-server override, and server-removal reroute.""" + servers = ["server1", "server2", "server3"] + router = ConversationRouter(server_role=None, servers=servers) + + # Affinity: same conversation_id → same server + req = _make_request(conversation_id="sess-A") + first, _ = await router.get_next_server(req) + await router.finish_request(req) + for _ in range(3): + req = _make_request(conversation_id="sess-A") + s, _ = await router.get_next_server(req) + assert s == first + await router.finish_request(req) + + # Exclude: affinity overridden when mapped server is excluded + req = _make_request(conversation_id="sess-A") + s, _ = await router.get_next_server(req, exclude_server=first) + assert s != first + await router.finish_request(req) + + # Server removal: session re-routes to surviving server + req = _make_request(conversation_id="sess-B") + orig, _ = await router.get_next_server(req) + await router.finish_request(req) + await router.remove_server(orig) + req = _make_request(conversation_id="sess-B") + s, _ = await router.get_next_server(req) + assert s != orig and s in router.servers + await router.finish_request(req) + + +@pytest.mark.asyncio +async def test_conversation_router_load_balancing(): + """New sessions with distinct prompts are load-balanced across servers.""" + servers = ["server1", "server2", "server3"] + router = ConversationRouter(server_role=None, servers=servers) + + assigned, reqs = [], [] + for i in range(3): + req = _make_request(prompt=f"unique topic {i} " * 50) + s, _ = await router.get_next_server(req) + assigned.append(s) + reqs.append(req) + assert sorted(assigned) == sorted(servers) + for req in reqs: + await router.finish_request(req) + + # Session affinity survives interleaved non-session requests + req_x = _make_request(conversation_id="sess-X", prompt="topic X") + sx, _ = await router.get_next_server(req_x) + await router.finish_request(req_x) + req_x2 = _make_request(conversation_id="sess-X", prompt="topic X turn 2") + s, _ = await router.get_next_server(req_x2) + assert s == sx + await router.finish_request(req_x2) + + +@pytest.mark.asyncio +async def test_conversation_router_prefix_and_token_id_paths(): + """Implicit prefix matching (text and token-ID paths) and hash_skip_count.""" + servers = ["server1", "server2", "server3"] + base = "x" * 2000 + + # Text prefix matching: multi-turn without conversation_id + router = ConversationRouter(server_role=None, servers=servers) + req1 = _make_request(prompt=base) + s1, _ = await router.get_next_server(req1) + await router.finish_request(req1) + for ext in ["y" * 50, "y" * 50 + "z" * 50]: + req = _make_request(prompt=base + ext) + s, _ = await router.get_next_server(req) + assert s == s1, "Extended prompt should prefix-match turn 
1" + await router.finish_request(req) + + # Token-ID path: CompletionRequest with list[list[int]] + router2 = ConversationRouter(server_role=None, servers=servers) + base_ids = [1000] * 2000 + dp = DisaggregatedParams(request_type="context_only") + req_t1 = CompletionRequest(model="TinyLlama", + prompt=[base_ids + [0]], + disaggregated_params=dp) + st1, _ = await router2.get_next_server(req_t1) + await router2.finish_request(req_t1) + + req_t2 = CompletionRequest(model="TinyLlama", + prompt=[base_ids + [2000] * 50 + [0]], + disaggregated_params=dp) + st2, _ = await router2.get_next_server(req_t2) + assert st2 == st1, "Token-ID prefix should match" + await router2.finish_request(req_t2) + + # ChatCompletionRequest with prompt_token_ids + req_t3 = ChatCompletionRequest(model="TinyLlama", + messages=[{ + "role": "user", + "content": "dummy" + }], + prompt_token_ids=base_ids + [2000] * 50 + + [3000] * 50 + [0], + disaggregated_params=dp) + st3, _ = await router2.get_next_server(req_t3) + assert st3 == st1, "ChatCompletion token-ID path should match" + await router2.finish_request(req_t3) + + +@pytest.mark.asyncio +async def test_conversation_router_hash_skip_count(): + """hash_skip_count strips shared system-prompt prefix. + + With tokens_per_block=128 (chars, via code-point path): + - sys_prompt "S"*2000 → ~15 blocks, unique "A"*500 → ~4 blocks + - Total ~19 blocks, shared ratio ~15/19 ≈ 0.79 > 0.75 threshold + Without skip the shared prefix triggers a false implicit match. + With skip (hash_skip_count=400 → strips 400*5=2000 chars), the + shared prefix is removed and the remaining content differs. + """ + servers = ["server1", "server2", "server3"] + sys_prompt = "S" * 2000 + + # Without skip: shared prefix causes false match + r1 = ConversationRouter(server_role=None, servers=servers) + req_a = _make_request(prompt=sys_prompt + "A" * 500) + sa, _ = await r1.get_next_server(req_a) + await r1.finish_request(req_a) + req_b = _make_request(prompt=sys_prompt + "B" * 500) + sb, _ = await r1.get_next_server(req_b) + await r1.finish_request(req_b) + assert sb == sa, "Without skip, shared prefix causes false match" + + # With skip: different content after prefix → no match + r2 = ConversationRouter(server_role=None, + servers=servers, + hash_skip_count=400) + req_a2 = _make_request(prompt=sys_prompt + "A" * 500) + sa2, _ = await r2.get_next_server(req_a2) + # Keep in-flight so LB prefers a different server + req_b2 = _make_request(prompt=sys_prompt + "B" * 500) + sb2, _ = await r2.get_next_server(req_b2) + await r2.finish_request(req_a2) + await r2.finish_request(req_b2) + assert sb2 != sa2, "With skip, different content should not match" + + +def test_create_router_conversation(): + router = create_router(RouterConfig(type="conversation"), + ["server1", "server2"]) + assert isinstance(router, ConversationRouter)