diff --git a/.github/workflows/auto-assign-author.yml b/.github/workflows/auto-assign-author.yml index 6801379c0be..c4e655ce07b 100644 --- a/.github/workflows/auto-assign-author.yml +++ b/.github/workflows/auto-assign-author.yml @@ -15,4 +15,4 @@ jobs: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} PR_URL: ${{ github.event.pull_request.html_url }} AUTHOR: ${{ github.actor }} - run: gh pr edit $PR_URL --add-assignee $AUTHOR + run: gh pr edit $PR_URL --add-assignee $AUTHOR || echo "Could not assign $AUTHOR (not a collaborator), skipping." diff --git a/tensorrt_llm/_torch/autotuner.py b/tensorrt_llm/_torch/autotuner.py index 8b60d42b522..1c6684ea3e4 100644 --- a/tensorrt_llm/_torch/autotuner.py +++ b/tensorrt_llm/_torch/autotuner.py @@ -201,7 +201,7 @@ def get_valid_tactics(self, inputs: List[torch.Tensor], tactic==-1 has special meaning, means the fallback kernel which should be able to implement any shapes This fallback tactic is needed for 2 reasons: - * when the autotuner cannot find a valid tactic in it's cache. + * when the autotuner cannot find a valid tactic in its cache. * in eager mode, w/o autotunning the custom op should have at least one kernel, which makes the autotuning process an optional process, such that user can opt out. @@ -1437,10 +1437,10 @@ def _create_tensor_like(self, origin_tensor: torch.Tensor, # during the tuning process. This can by controlled in the preparation phase by the runner. # It must not use all zero tensors. Otherwise the timing results become unreliable. if dtype == torch.float4_e2m1fn_x2: - return torch.randint(-5, 5, shapes, - device=device).to(torch.uint8).view(dtype) + return (torch.rand(shapes, device=device) * 10 - 5).to( + torch.uint8).view(dtype) else: - return torch.randint(-5, 5, shapes, device=device).to(dtype) + return (torch.rand(shapes, device=device) * 10 - 5).to(dtype) def _prepare_input_tensors( self, profile: OptimizationProfile, diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index f88d13ddbd3..dada2bc8d03 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -1069,7 +1069,7 @@ def get_valid_tactics(self, inputs: List[torch.Tensor], def get_dynamic_tensor_specs(cls) -> Tuple[DynamicTensorSpec, ...]: """Get the dynamic tensor specs for use with the AutoTuner.""" - # These indices correspond to the 0th input tensor and it's first dimension + # These indices correspond to the 0th input tensor and its first dimension # i.e. 
we are tuning M where the first input tensor is of shape [B, M, K] MAT1_IDX = 0 diff --git a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py index 9ec06dbe7fb..0addc29a566 100644 --- a/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/trtllm_gen_custom_ops.py @@ -284,7 +284,7 @@ def get_dynamic_tensor_specs(cls, ep_size: int) -> Tuple[DynamicTensorSpec, ...]: HIDDEN_STATES_IDX = 2 TUNED_DIM = 0 - MAX_PROFILE_BUCKET = 4096 + MAX_PROFILE_BUCKET = 8192 m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET) @@ -660,7 +660,7 @@ def get_dynamic_tensor_specs(cls, ep_size: int) -> Tuple[DynamicTensorSpec, ...]: HIDDEN_STATES_IDX = 2 TUNED_DIM = 0 - MAX_PROFILE_BUCKET = 4096 + MAX_PROFILE_BUCKET = 8192 m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET) @@ -967,7 +967,7 @@ def get_dynamic_tensor_specs(cls, ep_size: int) -> Tuple[DynamicTensorSpec, ...]: HIDDEN_STATES_IDX = 2 TUNED_DIM = 0 - MAX_PROFILE_BUCKET = 4096 + MAX_PROFILE_BUCKET = 8192 m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET) @@ -1273,7 +1273,7 @@ def get_dynamic_tensor_specs(cls, ep_size: int) -> Tuple[DynamicTensorSpec, ...]: HIDDEN_STATES_IDX = 2 TUNED_DIM = 0 - MAX_PROFILE_BUCKET = 4096 + MAX_PROFILE_BUCKET = 8192 m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET) @@ -1561,7 +1561,7 @@ def get_dynamic_tensor_specs(cls, ep_size: int) -> Tuple[DynamicTensorSpec, ...]: HIDDEN_STATES_IDX = 2 TUNED_DIM = 0 - MAX_PROFILE_BUCKET = 4096 + MAX_PROFILE_BUCKET = 8192 m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET) @@ -1833,7 +1833,7 @@ def get_dynamic_tensor_specs(cls, ep_size: int) -> Tuple[DynamicTensorSpec, ...]: HIDDEN_STATES_IDX = 2 TUNED_DIM = 0 - MAX_PROFILE_BUCKET = 4096 + MAX_PROFILE_BUCKET = 8192 m_values = get_last_power_of_2_num_tokens_buckets(MAX_PROFILE_BUCKET) diff --git a/tensorrt_llm/_torch/models/modeling_exaone_moe.py b/tensorrt_llm/_torch/models/modeling_exaone_moe.py index d960d515235..fe420178558 100644 --- a/tensorrt_llm/_torch/models/modeling_exaone_moe.py +++ b/tensorrt_llm/_torch/models/modeling_exaone_moe.py @@ -62,7 +62,21 @@ def check_is_moe(config: ExaoneMoEConfig, layer_idx: int, is_mtp_layer: bool = F """ Check if the current layer is a MoE layer. """ - return not is_mtp_layer and hasattr(config, "is_moe_layer") and config.is_moe_layer[layer_idx] + # The MTP layer of K-EXAONE is always dense. + if is_mtp_layer: + return False + + if hasattr(config, "mlp_layer_types") and config.mlp_layer_types is not None: + return config.mlp_layer_types[layer_idx] == "sparse" + + # For backward compatibility, older K-EXAONE checkpoints do not include `mlp_layer_types`. + if hasattr(config, "is_moe_layer") and config.is_moe_layer is not None: + return config.is_moe_layer[layer_idx] + + raise ValueError( + "Invalid configuration: Neither `mlp_layer_types` nor `is_moe_layer` found in config. " + "Please check if the checkpoint and config are compatible with ExaoneMoEConfig." 
+ ) def enable_attn_allreduce(mapping: Mapping): diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 5afd1d36947..8b785e18e92 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -811,7 +811,7 @@ def apply_rope(self, q: torch.Tensor, k: Optional[torch.Tensor], def apply_qk_norm(self, q, k): raise NotImplementedError( - f"QK norm is not implemented for {self.__class__.__name__}." + f"QK norm is not implemented for {self.__class__.__name__}. " "Please override the `apply_qk_norm` method in the subclass.") @@ -959,7 +959,7 @@ def __init__( self) self.register_to_config = True - # only support one kind of sparse attention, dsa now. + # Currently only DSA sparse attention is supported. if config is not None and config.sparse_attention_config is not None and config.sparse_attention_config.algorithm == "dsa": self.is_dsa = True else: @@ -982,7 +982,7 @@ def __init__( dp_size = tp_size tp_size = 1 if self.mapping.has_cp_ulysses(): - raise NotImplementedError("MLA doesn't support CP Ulyssees yet") + raise NotImplementedError("MLA doesn't support CP Ulysses yet") if self.mapping.cp_size > 1: assert self.mapping.has_cp_helix( ), f"CP type must be HELIX for MLA, but got {self.mapping.cp_config['cp_type']}." @@ -1360,13 +1360,13 @@ def forward_impl(self, output: torch.Tensor, latent_cache_gen: Optional[torch.Tensor] = None) -> None: """ - Forward pass for the MLA module. + Forward pass for the MLA module. Writes result into output tensor in-place. Args: position_ids (Optional[torch.IntTensor]): The position IDs. hidden_states (torch.Tensor): The hidden states. attn_metadata (AttentionMetadata): The attention metadata. - output (torch.Tensor): Pre-allocated output tensor, written in-place. + output (torch.Tensor): The output tensor to write results into. latent_cache_gen (Optional[torch.Tensor]): The latent cache used in generation. """ # split q, k, v into context and gen batches @@ -1464,12 +1464,13 @@ def forward_impl_with_dsa(self, position_ids: Optional[torch.Tensor], output: torch.Tensor) -> None: """ Forward pass for the MLA module with DSA (always in MQA mode). + Writes result into output tensor in-place. Args: position_ids (Optional[torch.IntTensor]): The position IDs. hidden_states (torch.Tensor): The hidden states. attn_metadata (AttentionMetadata): The attention metadata. - output (torch.Tensor): Pre-allocated output tensor, written in-place. + output (torch.Tensor): The output tensor to write results into. 
""" assert self.mqa is not None, "DSA is only supported in MQA mode" # split q, k, v into context and gen batches @@ -1800,7 +1801,7 @@ def forward_context_with_chunked_prefill( # currently we assume that the chunk size is the same as the max_num_tokens chunked_loop_num = attn_metadata.chunked_loop_num - # [toal_token_q, num_heads, 2] -> [toal_token_q, num_heads] float2 + # [total_token_q, num_heads, 2] -> [total_token_q, num_heads] float2 self.softmax_stats_tensor = torch.empty( (attn_metadata.num_ctx_tokens, self.num_heads_tp, 2), dtype=torch.float, diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 99a6847461c..f3ea6e9a096 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -56,7 +56,7 @@ def get_moe_cls( return TRTLLMGenFusedMoE else: logger.warning( - "TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8. " + "TRTLLMGenFusedMoE only supports fp8_block_scales, nvfp4, w4a16_mxfp4, w4a8_nvfp4_fp8, w4a8_mxfp4_fp8, and w4a8_mxfp4_mxfp8. " f"Check out details in quant_config: {quant_config}. Using CutlassFusedMoE instead." ) return CutlassFusedMoE @@ -140,7 +140,7 @@ def create_moe_backend( assert moe_cls in [ WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE, CuteDslFusedMoE, DeepGemmFusedMoE - ], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE and CuteDslFusedMoE, and DeepGemmFusedMoE." + ], "MoE Load Balance is only supported in WideEPMoE, CutlassFusedMoE, TRTLLMGenFusedMoE, CuteDslFusedMoE, and DeepGemmFusedMoE." if bias: assert moe_cls in [CutlassFusedMoE, TritonFusedMoE, TRTLLMGenFusedMoE @@ -371,14 +371,14 @@ def create_moe( activation_type=activation_type, ) else: - # Check if this is a TRTLLM backend request that fallback to CutlassFusedMoE + # Check if this is a TRTLLM or CUTEDSL backend request that fell back to CutlassFusedMoE requested_backend = model_config.moe_backend.upper() if requested_backend in ("TRTLLM", "CUTEDSL") and moe_cls == CutlassFusedMoE: - # Workaround for test cases where TRTLLM backend fallbacks to CutlassFusedMoE due to quant_config incompatibility + # Workaround for test cases where TRTLLM backend falls back to CutlassFusedMoE due to quant_config incompatibility # Log warning and continue with the fallback backend logger.warning( - f"ENABLE_CONFIGURABLE_MOE is set but TRTLLM backend fallback to {moe_cls.__name__} due to quant_config. " + f"ENABLE_CONFIGURABLE_MOE is set but {requested_backend} backend fell back to {moe_cls.__name__} due to quant_config. " f"ConfigurableMoE only supports TRTLLMGenFusedMoE and CuteDslFusedMoE backends. 
" f"Continuing with legacy MoE backend {moe_cls.__name__}.") else: diff --git a/tensorrt_llm/_torch/modules/fused_moe/routing.py b/tensorrt_llm/_torch/modules/fused_moe/routing.py index ebc3cdf03a8..69498c96cfc 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/routing.py +++ b/tensorrt_llm/_torch/modules/fused_moe/routing.py @@ -75,7 +75,7 @@ def precompute_common_perfect_router_logits(num_experts: int, 5120, 6144, 7168, - 8192 # Powers of 2 and common sizes + 8192 # Common sizes ] print( diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index 927acdcd289..4e6b69954de 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ b/tensorrt_llm/_torch/modules/linear.py @@ -91,7 +91,7 @@ def split_dim(cls, mode): return 1 if mode == cls.ROW else 0 # Helper to shard the corresponding per-channel activation scales - # Which shard along the dimension orthogonal to the weights + # which are sharded along the dimension orthogonal to the weights @classmethod def flip(cls, mode): return cls.ROW if mode == cls.COLUMN else cls.COLUMN @@ -190,7 +190,7 @@ def load_weights_vanilla_helper(module: Linear, if weight is not None: if module.has_weight_only_quant: - # NOTE: without the preprocess during the runtime, the gemm output nan's. in order to use the preprocess_weights_for_mixed_gemm + # NOTE: without the preprocess during the runtime, the gemm outputs NaNs. In order to use the preprocess_weights_for_mixed_gemm # we need to cast the weight to int8 first. activation_dtype = torch.float8_e4m3fn if module.has_w4a8_awq else torch.float16 weight_dtype, _ = get_weight_dtype_and_id(module) @@ -571,7 +571,7 @@ def create_weights(self, module: Linear, in_features: int, # K, V scales for NVFP4 KV cache module.kv_scales = Parameter(torch.ones(3, dtype=torch.float32), requires_grad=False) - # K, V scales for NVFP4 KV cache + # Inverse K, V scales for NVFP4 KV cache module.inv_kv_scales = Parameter(torch.ones(3, dtype=torch.float32), requires_grad=False) if bias: @@ -919,7 +919,7 @@ def apply(self, module: Linear, input: torch.Tensor, return output def _get_scale_name(self, weights: List[Dict]): - # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe. + # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe. # Actually they hold identical values of data_amax / 448. scale_name = "weight_scale_inv" if scale_name not in weights[0]: @@ -1065,7 +1065,7 @@ def apply(self, module: Linear, input: torch.Tensor, return output def _get_scale_name(self, weights: List[Dict]): - # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe. + # `weight_scale_inv` for DS recipe and `weight_scale` for ModelOpt recipe. # Actually they hold identical values of data_amax / 448. 
for w in weights: if "weight_scale_inv" in w: @@ -1230,7 +1230,7 @@ def create_weights(self, module: Linear, in_features: int, # K, V scales for NVFP4 KV cache module.kv_scales = Parameter(torch.ones(3, dtype=torch.float32), requires_grad=False) - # K, V scales for NVFP4 KV cache + # Inverse K, V scales for NVFP4 KV cache module.inv_kv_scales = Parameter(torch.ones(3, dtype=torch.float32), requires_grad=False) @@ -1531,14 +1531,8 @@ def load_weights_fused_gate_up_linear(self, module: Linear, copy_weight(module.pre_quant_scale, pre_quant_scale) def post_load_weights(self, module: Linear): + """Pad weight and weight_scale tensors to meet torch trtllm NVFP4 GEMM alignment requirements.""" super().post_load_weights(module) - """ - Pad weight and weight_scale tensors to meet torch trtllm NVFP4 GEMM alignment requirements. - - Args: - row_alignment: Required row alignment (default: 32) - col_alignment: Required column alignment (default: 16) - """ row_alignment, col_alignment = 32, 16 row_pad_size = (row_alignment - module.weight.size(0)) % row_alignment col_pad_size = (col_alignment - module.weight.size(1)) % col_alignment @@ -1682,7 +1676,7 @@ def load_weight_scales( weight_scale_2 = w["weight_scale_2"][...] else: assert weight_scale_2 == w["weight_scale_2"][...], ( - f"The weight_scale_2 should be same for all the weights: {weight_scale_2} vs. {w['weight_scale_2']}*6" + f"The weight_scale_2 should be the same for all the weights: {weight_scale_2} vs. {w['weight_scale_2']}" ) # TODO: ModelOpt's o_proj.weight_scale_2 is bfloat16, which should be float32 @@ -2195,7 +2189,7 @@ def apply(self, module: Linear, input: torch.Tensor, 1. multiply pre_quant_scale to input 2. quantize input to fp8 using input_scale 3. unpack_weights and multiply by weight_scales (int4 -> fp16) - 4. divied by weight_scale_2 (fp16 -> fp8 to allow gemm in fp8). + 4. divided by weight_scale_2 (fp16 -> fp8 to allow gemm in fp8). 5. apply gemm in fp8. 6. rescale using alpha which is input_scale * weight_scale_2 """ @@ -2731,7 +2725,7 @@ def load_weights(self, weight_mode = self.weights_loading_config.weight_mode if not isinstance(self.quant_method, UnquantizedLinearMethod): - assert allow_partial_loading is False, "allow_partial_loading is only supported for non-unquantized linear methods now" + assert allow_partial_loading is False, "allow_partial_loading is only supported for unquantized linear methods now" self.quant_method.load_weights( self, weights, diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index be44575a5c5..1bedaffccf3 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -362,7 +362,7 @@ def __init__( # 3) The model configuration is not loaded until the model engine # is initialized. # - # NOTE: This can simplified by decoupling the model config loading and + # NOTE: This can be simplified by decoupling the model config loading and # the model engine. self.attn_metadata = None self.iter_states = {} @@ -904,8 +904,8 @@ def _capture_piecewise_cuda_graphs(self, resource_manager: ResourceManager): gc.collect() torch.cuda.empty_cache() - # When using piecewise cuda graph, the logits may suffer severe memory faction problem. - # When the num of requests is growing, the block allocated by torch cannot be reused. + # When using piecewise cuda graph, the logits may suffer from a severe memory fragmentation problem. + # As the number of requests grows, the blocks allocated by torch cannot be reused.
# So after piecewise cuda graph capture, a request with most requests is triggered to make # sure that large enough blocks are allocated and can be correctly reused. for num_tokens in piecewise_cuda_graph_num_tokens: @@ -1389,14 +1389,14 @@ def _release_cuda_graphs(self): def get_max_num_sequences(self) -> int: """ - Return the maximum number of sequences that the model supports. PyExecutor need this to compute max_num_active_requests + Return the maximum number of sequences that the model supports. PyExecutor needs this to compute max_num_active_requests """ num_batches = self.mapping.pp_size return num_batches * self.batch_size def _preprocess_inputs(self, inputs: Dict[str, Any]): """ - Make some changes to the device inputs and avoid block the async data transfer + Make some changes to the device inputs and avoid blocking the async data transfer """ if self.enable_spec_decode and not self._disable_overlap_scheduler: # When enabling overlap scheduler, the kv cache for draft tokens will @@ -1554,7 +1554,7 @@ def get_padded_piecewise_tokens(tokens): return padded_num_tokens, True, None else: logger.debug( - f"Picewise cudagraph cannot be used with {total_num_tokens} tokens, {num_ctx_requests} context requests" + f"Piecewise CUDA graph cannot be used with {total_num_tokens} tokens, {num_ctx_requests} context requests" ) return total_num_tokens, False, None diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 928a95de66f..1f4049ebddb 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -132,7 +132,7 @@ class BatchStatePP(BatchState): class AsyncTransferManager: """ - Handle asynchronous transfer or KV cache after a request has completed. + Handle asynchronous transfer of KV cache after a request has completed. When running with both the KV cache transceiver and the KV cache connector, we must ensure that BOTH transfers (if any) are completed before we can release the KV cache blocks. The AsyncTransferManager has a few key responsibilities: 1. Track requests in transfer. @@ -171,7 +171,7 @@ def __init__(self, # Mapping of request id to the LlmRequest self._requests_in_transfer: Dict[int, LlmRequest] = dict() - # Mapping of request id to the the request metadata + # Mapping of request id to the request metadata self._request_transfer_metadata: Dict[ int, self.RequestTransferMetadata] = dict() @@ -463,12 +463,12 @@ def on_detected(): batch_wait_timeout_ms=self.batch_wait_timeout_ms, ) # When overlap scheduler is enabled then when starting to handle a new prompt, - # sample_async is called twice before the first call to update_requests: - # - 1st time as a context request that handles on the 1st generated token - # - 2nd time as a generation request that handles on the 2nd generated token. + # _sample_async is called twice before the first call to update_requests: + # - 1st time as a context request that operates on the 1st generated token + # - 2nd time as a generation request that operates on the 2nd generated token. # and only after these two calls the sampler's update_request method is called. # So in a sampler that works by the expected flow of handling the logits in - # sample_async, every update_request doesn't handle the newest token, but one + # _sample_async, every update_request doesn't handle the newest token, but one # before it. 
Since all these calls work on the same request object, then its # logits storage contains the logits of both the token update_requests should work # on, and also its next token. Thus, excluding the last generation logits from any @@ -648,7 +648,7 @@ def _set_global_steady_clock_offset(self): def __enter__(self): return self - def __exit__(self): + def __exit__(self, exc_type, exc_val, exc_tb): self.shutdown() def enqueue_requests( @@ -672,7 +672,7 @@ def await_responses( timeout: Optional[datetime.timedelta] = None, ) -> Union[List[List[LlmResponse]], List[LlmResponse]]: """ - Await for ready responses + Await ready responses Args: id (Optional[Union[List[int], int]]): Request id timeout (Optional[datetime.timedelta]): The maximum time to wait for new responses @@ -1317,8 +1317,8 @@ def _executor_loop_pp(self): self.resource_manager.prepare_resources(scheduled_batch) - # The generation requests that are do not have batch_idx, - # needs to be in front of the batch due to the assumptions + # The generation requests that do not have batch_idx + # need to be in front of the batch due to the assumptions # made in model_engine.py::_forward_step. This is only important # for disaggregated serving. For non-disaggregated serving, # the generation requests always have batch_idx. @@ -1621,7 +1621,7 @@ def _can_queue(self, scheduled_batch): # can_queue_this_rank is for case that the batch is not empty on this rank, but empty on other ranks # For bs == 1, we cannot pad dummy request to make the batch non-empty since it will cause the batch size to be 2. - # 1 for dummy request, 1 for the to complete but haven't updated request. + # 1 for dummy request, 1 for the yet-to-complete but not-yet-updated request. if self.enable_attention_dp: tp_batch_sizes = self.dist.tp_allgather(scheduled_batch.batch_size) can_queue = 0 not in tp_batch_sizes @@ -2092,8 +2092,8 @@ def _executor_loop_overlap(self): should_process_previous_batch = can_queue or not can_queue_this_rank if can_queue: - # The generation requests that are do not have batch_idx, - # needs to be in front of the batch due to the assumptions + # The generation requests that do not have batch_idx + # need to be in front of the batch due to the assumptions # made in model_engine.py::_forward_step. This is only important # for disaggregated serving. For non-disaggregated serving, # the generation requests always have batch_idx. @@ -2153,9 +2153,6 @@ def _executor_loop_overlap(self): if self.previous_batch is not None and should_process_previous_batch: self._update_requests(self.previous_batch.sample_state) - self.perf_manager.compute_batch_gpu_times( - self.previous_batch.all_requests) - self._send_kv_async(self.previous_batch.all_requests) if self.drafter is not None and self.use_spec_decode and should_process_previous_batch: @@ -2177,11 +2174,6 @@ def _executor_loop_overlap(self): sample_state = self._sample_async( scheduled_batch, batch_outputs) - self.perf_manager.save_timing_to_requests( - scheduled_batch.all_requests(), gpu_forward_start, - gpu_forward_end, gpu_sample_end, fwd_timing.start_time, - fwd_timing.end_time, sample_timing.start_time, - sample_timing.end_time) assert sample_state is not None, "Sampling failed" # Handle guided decoder errors after _sample_async to avoid state conflicts. 
@@ -2193,6 +2185,8 @@ def _executor_loop_overlap(self): if self.previous_batch is not None and should_process_previous_batch: self._process_previous_batch() + self.perf_manager.compute_batch_gpu_times( + self.previous_batch.all_requests) else: self._enqueue_responses([]) @@ -2204,6 +2198,11 @@ def _executor_loop_overlap(self): scheduled_batch.generation_requests) if can_queue: + self.perf_manager.save_timing_to_requests( + scheduled_batch.all_requests(), gpu_forward_start, + gpu_forward_end, gpu_sample_end, fwd_timing.start_time, + fwd_timing.end_time, sample_timing.start_time, + sample_timing.end_time) if self.enable_iter_perf_stats: iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ 'num_ctx_tokens'] @@ -2735,7 +2734,7 @@ def flag_if_kv_transfer_timed_out(req: LlmRequest, type: str) -> None: def _check_disagg_ctx_schedulable_status(self, new_requests: List[LlmRequest]): """ - In context-first mode, context requests are scheduable immediately, + In context-first mode, context requests are schedulable immediately, otherwise, we need to check if context requests are ready to be scheduled by querying kv cache transceiver """ if not self.kv_cache_transceiver: @@ -3482,10 +3481,10 @@ def key_has_response(): self.responses.pop(id) return response - def _terminate_requests(self, requests_to_pause): + def _terminate_requests(self, requests_to_terminate): # todo: support work with self.inflight_req_ids. - # Currently, self.inflight_req_ids is not. - for req in requests_to_pause: + # Currently, self.inflight_req_ids is not updated. + for req in requests_to_terminate: self._terminate_request(req) def _pause_requests(self, requests_to_pause): diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 9306523d6b5..2d826d8c9a0 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -130,7 +130,7 @@ def get_pp_layers( assert sum(layer_mask) == num_layers, ( f"The number of enabled layers in layer_mask ({sum(layer_mask)}) " f"must match the number of layers ({num_layers}) " - f"in KV cache manager, but get layer_mask: {layer_mask}") + f"in KV cache manager, but got layer_mask: {layer_mask}") total_num_layers = len(layer_mask) pp_layers = mapping.pp_layers(total_num_layers) if layer_mask is not None: @@ -735,7 +735,7 @@ def add_dummy_requests( for _ in range(max_num_draft_tokens): draft_kv_cache_manager.impl.add_token(req_id) - # TODO: Planning to get dummy_data from each model. Before that, we need to add dummy mrop_config to the request here. + # TODO: Planning to get dummy_data from each model. Before that, we need to add dummy mrope_config to the request here. 
if use_mrope: dummy_mrope_position_ids = torch.arange( 0, token_num, dtype=torch.int32).expand(3, 1, -1).clone() @@ -927,7 +927,7 @@ def get_max_atten_window_upper_bound(self, blocks_in_primary_pool, max_atten_window_upper_bound = max_token_num - sink_bubble_len if max_seq_len is not None and max_seq_len > max_atten_window_upper_bound and max_beam_width > 1: max_atten_window_upper_bound -= tokens_per_block - assert max_atten_window_upper_bound > 0, "Impossibe to fit in any sequence in kvCache" + assert max_atten_window_upper_bound > 0, "Impossible to fit any sequence in kvCache" return max_atten_window_upper_bound def get_cache_indices(self, @@ -2106,7 +2106,7 @@ def release_resources(current_request: LlmRequest, f"Failed to resize capacity of draft KV cache for request {req.py_request_id} to {new_capacity} tokens for dummy request" ) - # TODO: Planning to get dummy_data from each model. Before that, we need to add dummy mrop_config to the request here. + # TODO: Planning to get dummy_data from each model. Before that, we need to add dummy mrope_config to the request here. if use_mrope: dummy_mrope_position_ids = torch.arange( 0, token_num, dtype=torch.int32).expand(3, 1, -1).clone() @@ -2208,7 +2208,7 @@ def get_layer_bytes_per_token(self, local_layer_idx: int, data_role: Role): if data_role in [Role.KEY_BLOCK_SCALE, Role.VALUE_BLOCK_SCALE]: assert self.dtype == DataType.NVFP4, "NVFP4 is the only supported dtype for block quant data roles" if data_role == Role.VALUE: - assert self.kv_cache_type != CacheTypeCpp.SELFKONLY, "SELFKONLY is the only supported cache type for value data role" + assert self.kv_cache_type != CacheTypeCpp.SELFKONLY, "VALUE data role is not supported for SELFKONLY cache type" kv_factor = 1 else: raise ValueError(f"Invalid data role: {data_role}") diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 197680cf166..54477281374 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -2868,13 +2868,13 @@ def _get_logprobs_from_request( pin_memory: bool = True, preallocate_extra_steps: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: - pin_memory = pin_memory and prefer_pinned() - """Extract the logprobs from the request + """Extract the logprobs from the request. Returns: logprobs_tensor: A tensor of shape (beam_width, num_generated_tokens, num_logprobs) logprobs_indices_tensor: A tensor of shape (beam_width, num_generated_tokens, num_logprobs) """ + pin_memory = pin_memory and prefer_pinned() num_generated_tokens = request.max_beam_num_tokens - request.py_prompt_len assert request.py_num_logprobs == 0, ( "Beam search only supports returning the sampled logprob per token" ) @@ -3056,10 +3056,9 @@ def _finalize_beam( ) -> None: """Update the request with the corrected tokens and logprobs for each beam.
- arguments: + Args: request: The request to update beam_history: The beam history used to update the request - finish_reasons: The finish reasons to use to check if the beam is finished (Shape: (beam_width,)) """ beam_width = request.sampling_config.beam_width @@ -3468,7 +3467,7 @@ def _apply_embedding_bias( """ # NB: Unfortunately, Torch provides no combination of torch.index_select (similar to # torch.Tensor.gather -- allows one-to-many mapping) and addition, analogous to how - # torch.Tensor.scatter_add_ (and it's variant torch.Tensor.index_add_ -- allows + # torch.Tensor.scatter_add_ (and its variant torch.Tensor.index_add_ -- allows # many-to-one mapping) combine addition with torch.Tensor.scatter_. # # Notwithstanding the previous point, there are two options: diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 5084abbd274..a09e6075258 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -296,7 +296,7 @@ class Eagle3OneModelSpecMetadata(SpecMetadata): max_num_tokens: int = 0 # The dtype of the hidden states dtype: torch.dtype = torch.bfloat16 - # The index of the batche inputs + # The index of the batch inputs batch_indices_cuda: Optional[torch.Tensor] = None def __post_init__(self): @@ -335,7 +335,7 @@ def is_layer_capture(self, layer_id: int): def prepare(self): assert self.request_ids is not None - # update batch indeices + # update batch indices num_seqs = len(self.request_ids) batch_indices = torch.arange(num_seqs, dtype=torch.int, @@ -505,7 +505,7 @@ def forward(self, attn_metadata._seq_lens[:batch_size].fill_(1) attn_metadata._seq_lens_cuda[:batch_size].fill_(1) attn_metadata.on_update() - # cannot run generation if their is no kv cache + # cannot run generation if there is no kv cache if inputs["attn_metadata"].kv_cache_manager is not None: attn_metadata.host_request_types[:attn_metadata. num_contexts].fill_(1) diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 49c0ff28f11..f8a59c93fe9 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -119,14 +119,14 @@ class MTPSpecMetadata(SpecMetadata): mtp_hidden_states_manager: Optional[MTPHiddenStatesManager] = None # The slot ids for each request. slot_ids: Optional[torch.Tensor] = None - # The index of the batche inputs + # The index of the batch inputs batch_indices_cuda: Optional[torch.Tensor] = None # The number of sequences for speculative model/layer of different rank _all_rank_num_seqs: Optional[List[int]] = None # This is used for attention dp in the MTP Eagle worker. The numbers of input # tokens varies between the 1st draft forward and subsequent ones. To support # CUDA graph, we use this tensor to store the number of input tokens for the - # subsequence draft forward. + # subsequent draft forward. subseq_all_rank_num_tokens: Optional[List[int]] = None # Optional suffix automaton manager for MTP+SA speculative decoding sa_manager: Optional[SuffixAutomatonManager] = None @@ -173,7 +173,7 @@ def all_rank_num_seqs(self, value: List[int]): def prepare(self): assert self.request_ids is not None num_seqs = len(self.request_ids) - # update batch indeices + # update batch indices batch_indices = torch.arange(num_seqs, dtype=torch.int, device='cpu', @@ -310,7 +310,7 @@ def forward( - hidden states: H_A, H_B, H_C, H_D Draft model: MTP1: - # For context request, prompt[1:] + new generated goloden token is the input. 
+ # For context request, prompt[1:] + new generated golden token is the input. - input tokens: BCDE - input hidden states: H_A, H_B, H_C, H_D # '()' means historical KV cache @@ -527,14 +527,14 @@ def update_mtp_hidden_states( attn_metadata: AttentionMetadata, ): ''' - Update the past hidden states and past tokens in spec_metadata base on + Update the past hidden states and past tokens in spec_metadata based on the newly accepted tokens and historical hidden states. - These past hidden states and past tokens will be use in MTP module. + These past hidden states and past tokens will be used in MTP module. Args: input_ids: torch.IntTensor [num_tokens] - The input ids of all requests. Flatten. + The input ids of all requests. Flattened. hidden_states: torch.Tensor [num_tokens, hidden_size] @@ -692,7 +692,7 @@ def sample_and_accept_draft_tokens( Args: input_ids: torch.IntTensor [num_tokens] - The input ids of all requests. Flatten. + The input ids of all requests. Flattened. logits: torch.Tensor [num_tokens, vocab_size] @@ -897,17 +897,17 @@ def prepare_drafter_inputs( attn_metadata: AttentionMetadata, ): ''' - Parepare the input of the draft model. + Prepare the input of the draft model. Args: input_ids: torch.IntTensor [num_tokens] - The input ids of all requests. Flatten. + The input ids of all requests. Flattened. num_tokens = sum(all prompts) + num_generation * (mtp_num_modules + 1) position_ids: torch.IntTensor [1][num_tokens] - The position id of all requests. Flatten. + The position id of all requests. Flattened. hidden_states: torch.Tensor [num_tokens, hidden_size] @@ -930,12 +930,12 @@ def prepare_drafter_inputs( Returns: draft_inputs input_ids: torch.Tensor [num_tokens] - The new input ids of all requests. Flatten. + The new input ids of all requests. Flattened. num_tokens = sum(all prompts) + num_generation * (mtp_num_modules) position_ids: torch.Tensor - [1][[num_tokens]] - The new position ids of all requests. Flatten. + [1, num_tokens] + The new position ids of all requests. Flattened. Directly use the input position ids. 
hidden_states: torch.Tensor @@ -1253,7 +1253,7 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata): -1, hidden_states_gathered.shape[-1]) else: raise ValueError( - f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported" + "In MTPEagleWorker.forward(), token_count > max_num_requests, which is not supported" ) logits = draft_model.mtp_layers[0].shared_head( padded_hidden_states, draft_model.lm_head, @@ -1284,7 +1284,7 @@ def prepare_position_ids_and_last_tokens(position_ids, attn_metadata): attn_metadata._seq_lens[:batch_size].fill_(1) attn_metadata._seq_lens_cuda[:batch_size].fill_(1) attn_metadata.on_update() - # cannot run generation if their is no kv cache + # cannot run generation if there is no kv cache has_kv_cache = inputs[ "attn_metadata"].kv_cache_manager is not None if has_kv_cache: diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index f72de46d751..2cd47787559 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -2,7 +2,7 @@ import gc import json import os -import signal # Added import +import signal import socket import subprocess # nosec B404 import sys @@ -56,7 +56,6 @@ def _signal_handler_cleanup_child(signum, frame): """Signal handler to clean up the child process.""" global _child_p_global if _child_p_global and _child_p_global.poll() is None: - # Using print for safety in signal handlers logger.info( f"Parent process (PID {os.getpid()}) received signal {signal.Signals(signum).name}. Terminating child process (PID {_child_p_global.pid})." ) @@ -952,7 +951,7 @@ def serve_encoder(model: str, host: str, port: int, log_level: str, """ logger.set_level(log_level) - # TODO: expose more argument progressivly + # TODO: expose more arguments progressively llm_args, _ = get_llm_args(model=model, max_batch_size=max_batch_size, max_num_tokens=max_num_tokens, @@ -1084,7 +1083,7 @@ def disaggregated_mpi_worker(config_file: Optional[str], log_level: str): if os.environ.get(DisaggLauncherEnvs. TLLM_DISAGG_RUN_REMOTE_MPI_SESSION_CLIENT) != "1": set_cuda_device() - # Importing mpi4py after setting CUDA device. This is needed to war an issue with mpi4py and CUDA + # Importing mpi4py after setting CUDA device. This is needed to work around an issue with mpi4py and CUDA from mpi4py.futures import MPICommExecutor from tensorrt_llm._utils import global_mpi_rank, mpi_rank, set_mpi_comm @@ -1117,7 +1116,7 @@ def disaggregated_mpi_worker(config_file: Optional[str], log_level: str): f"mpi_session is provided for LLM instance. Global MPI rank: {global_mpi_rank()}, sub-comm MPI rank: {mpi_rank()}" ) - # Leader ranks will start the trtllm-server using it's own server config + # Leader ranks will start the trtllm-server using its own server config # and start a RemoteMPISessionServer to accept MPI tasks if is_leader: os.environ[DisaggLauncherEnvs.TLLM_DISAGG_INSTANCE_IDX] = str( diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index a6cd3690148..b3d096fa6a6 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -328,7 +328,7 @@ def aget_stats(self, timeout: float) -> IterationResult: """ if self._iter_stats_result is None: print_colored( - "Iteration statistics are not available yet. To collect runtime statistics, please call get_stats_async() in async coroutine or the /metrics endpoint (if you're using trtllm-serve) AFTER prompts have been submitted.\n", + "Iteration statistics are not available yet. 
To collect runtime statistics, please call aget_stats() in an async coroutine or use the /metrics endpoint (if you're using trtllm-serve) AFTER prompts have been submitted.\n", "yellow") return empty_async_iterable() @@ -538,8 +538,8 @@ def create( use_worker=True) # For single-gpu case: - # Partition the workload to multiple process for streaming performance. - # While this requires uses to protect their entrypoint to + # Partition the workload to multiple processes for streaming performance. + # Note that this requires users to protect their entrypoint with # `if __name__ == "__main__":`. if not platform.system() == 'Windows': if orchestrator_is_rpc: diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index d9f8bf46a51..f56e560d03a 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -659,8 +659,8 @@ def get_stats(self, timeout: Optional[float] = 2) -> List[dict]: timeout (float, optional): Max wait time in seconds when retrieving stats from queue. Defaults to 2. Returns: - List[dict]: A list of runtime stats as dict. - e.g., ['{"cpuMemUsage": ..., "iter": 0, ...}', '{"cpuMemUsage": ..., "iter": 1, ...}'] + List[dict]: A list of runtime stats as dicts. + e.g., [{"cpuMemUsage": ..., "iter": 0, ...}, {"cpuMemUsage": ..., "iter": 1, ...}] ''' return self._executor.get_stats(timeout=timeout) @@ -717,7 +717,7 @@ def get_kv_cache_events_async(self, - set `enable_block_reuse` to True in the `KvCacheConfig`. Args: - timeout (float, optional): Max wait time in seconds when retrieving events from queue. . Defaults to 2. + timeout (float, optional): Max wait time in seconds when retrieving events from queue. Defaults to 2. Returns: tensorrt_llm.executor.result.IterationResult: An async iterable object containing runtime events. @@ -761,7 +761,7 @@ def _prepare_sampling_params( f"The sampling_params must be type SamplingParams or None, but got {type(sampling_params)}" ) - # auto enabled context and/or generation logits flags, as they are required by logprob computation for TRT backend. + # auto enable context and/or generation logits flags, as they are required by logprob computation for TRT backend.
if self.args.backend not in ["pytorch", "_autodeploy"]: if sampling_params.prompt_logprobs and not sampling_params.return_context_logits: sampling_params.return_context_logits = True @@ -793,11 +793,11 @@ def _check_arguments(self, prompt_len: int, query_len: int, build_config = self.args.build_config - built_enging_cfg_file = Path(self.args.model) / 'config.json' - with open(built_enging_cfg_file) as f: - built_enging_cfg = json.load(f) - max_seq_len = built_enging_cfg['build_config'][ - 'max_seq_len'] if 'build_config' in built_enging_cfg else build_config.max_seq_len + built_engine_cfg_file = Path(self.args.model) / 'config.json' + with open(built_engine_cfg_file) as f: + built_engine_cfg = json.load(f) + max_seq_len = built_engine_cfg['build_config'][ + 'max_seq_len'] if 'build_config' in built_engine_cfg else build_config.max_seq_len # TODO: Remove this check and left the request verification to cpp runtime if (not self.args.enable_chunked_prefill) and ( diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 72e478d0c22..75c396867d6 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -213,7 +213,7 @@ class BaseSparseAttentionConfig(StrictBaseModel): def supports_backend(self, backend: str) -> bool: """ - Override if the speculation algorithm does not support + Override if the sparse attention algorithm does not support a subset of the possible backends. """ return True @@ -237,9 +237,9 @@ class RocketSparseAttentionConfig(BaseSparseAttentionConfig): """ algorithm: Literal["rocket"] = "rocket" window_size: Optional[int] = Field( - default=32, description="The window size for snap KV.") + default=32, description="The window size for RocketKV.") kernel_size: Optional[int] = Field( - default=63, description="The kernel size for snap KV.") + default=63, description="The kernel size for RocketKV.") topr: Optional[Union[int, float]] = Field(default=128, description="Top-r") topk: Optional[int] = Field(default=64, description="Top-k") prompt_budget: Optional[int] = Field(default=2048, @@ -636,7 +636,7 @@ class _ModelFormatKind(Enum): class DecodingBaseConfig(StrictBaseModel): max_draft_len: Optional[NonNegativeInt] = Field( - default=None, description="The number of drafter layers.") + default=None, description="The maximum number of draft tokens.") max_total_draft_tokens: Optional[int] = Field( default=None, @@ -651,7 +651,7 @@ class DecodingBaseConfig(StrictBaseModel): validation_alias=AliasChoices("speculative_model", "speculative_model_dir"), description= - "The speculative (draft) model. Accepts either (1) a HuggingFace Hub model ID (e.g. 'yuhuili/EAGLE3-LLaMA3.1-Instruct-8B')," + "The speculative (draft) model. Accepts either (1) a HuggingFace Hub model ID (e.g. 'yuhuili/EAGLE3-LLaMA3.1-Instruct-8B'), " "which will be automatically downloaded, or (2) a local filesystem path to a downloaded model directory." 
) @@ -934,7 +934,7 @@ def validate_eagle_config(self) -> 'EagleDecodingConfig': num_eagle_layers_from_choices = self.check_eagle_choices() if num_eagle_layers_from_choices != self.num_eagle_layers: logger.warning( - f"Base on the input choices, reset the num_eagle_layers(max_draft_len) from {self.num_eagle_layers} to {num_eagle_layers_from_choices}" + f"Based on the input choices, reset the num_eagle_layers(max_draft_len) from {self.num_eagle_layers} to {num_eagle_layers_from_choices}" ) self.num_eagle_layers = num_eagle_layers_from_choices self.max_draft_len = num_eagle_layers_from_choices @@ -1427,7 +1427,7 @@ def validate_ray_placement(self) -> 'RayPlacementConfig': class PybindMirror(ABC): ''' A class containing the utilities for mirroring Python classes to - pybinding classes. + pybind classes. ''' @abstractmethod @@ -1857,7 +1857,7 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): secondary_offload_min_priority: Optional[int] = Field( default=None, description= - "Only blocks with priority > mSecondaryOfflineMinPriority can be offloaded to secondary memory." + "Only blocks with priority > secondary_offload_min_priority can be offloaded to secondary memory." ) event_buffer_max_size: int = Field( default=0, @@ -2184,16 +2184,16 @@ class BaseLlmArgs(StrictBaseModel): moe_cluster_parallel_size: Optional[int] = Field( default=None, - description="The cluster parallel size for MoE models's expert weights.", + description="The cluster parallel size for MoE model's expert weights.", status="beta") moe_tensor_parallel_size: Optional[int] = Field( default=None, - description="The tensor parallel size for MoE models's expert weights.") + description="The tensor parallel size for MoE model's expert weights.") moe_expert_parallel_size: Optional[int] = Field( default=None, - description="The expert parallel size for MoE models's expert weights.") + description="The expert parallel size for MoE model's expert weights.") enable_attention_dp: bool = Field( default=False, @@ -2708,7 +2708,7 @@ def validate_speculative_config(self): elif isinstance(self.speculative_config, EagleDecodingConfig): assert self.speculative_config.max_draft_len > 0 - assert self.speculative_config.speculative_model is not None, "EAGLE3 draft model must be specified." + assert self.speculative_config.speculative_model is not None, "EAGLE draft model must be specified." self.build_config.max_draft_len = self.speculative_config.max_draft_len self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.EAGLE eagle_config = _EagleConfig( @@ -2958,7 +2958,7 @@ class TorchLlmArgs(BaseLlmArgs): garbage_collection_gen0_threshold: int = Field( default=20000, description= - "Threshold for Python garbage collection of generation 0 objects." + "Threshold for Python garbage collection of generation 0 objects. " "Lower values trigger more frequent garbage collection.", status="beta") @@ -3042,7 +3042,7 @@ class TorchLlmArgs(BaseLlmArgs): ge=0, le=1, description= - "Token accumulation threshold ratio for batch scheduling optimization. If greater than 0, the scheduler will accumulate requests locally until the total token count reaches batch_wait_max_tokens_ratio * max_num_tokens. This mechanism enhances GPU utilization efficiency by ensuring adequate batch sizes.If 0 disables token-based batching delays.", + "Token accumulation threshold ratio for batch scheduling optimization. If greater than 0, the scheduler will accumulate requests locally until the total token count reaches batch_wait_max_tokens_ratio * max_num_tokens. 
This mechanism enhances GPU utilization efficiency by ensuring adequate batch sizes. If 0, disables token-based batching delays.", status="prototype") torch_compile_config: Optional[TorchCompileConfig] = Field( @@ -3135,7 +3135,7 @@ class TorchLlmArgs(BaseLlmArgs): ray_worker_extension_cls: Optional[str] = Field( default=None, - description="The full worker extension class name including module path." + description="The full worker extension class name including module path. " "Allows users to extend the functions of the RayGPUWorker class.", status="prototype") @@ -3155,7 +3155,7 @@ class TorchLlmArgs(BaseLlmArgs): enable_sleep: bool = Field( default=False, description= - "Enable LLM sleep feature. Sleep feature requires extra setup that may slowdown model loading." + "Enable LLM sleep feature. Sleep feature requires extra setup that may slow down model loading. " "Only enable it if you intend to use this feature.", status="prototype") diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py index deec651bfef..ae0e076800b 100644 --- a/tensorrt_llm/llmapi/llm_utils.py +++ b/tensorrt_llm/llmapi/llm_utils.py @@ -92,7 +92,7 @@ class _ModelRuntimeContext: @property def model_arch(self) -> str: - # "LlaMACausalForLM" or "OPTForCausalLM" and so on + # "LlamaForCausalLM" or "OPTForCausalLM" and so on return self.engine.config.pretrained_config['architecture'] @@ -349,7 +349,7 @@ def _update_from_hf_quant_config(self) -> bool: hf_quant_algo = QuantAlgo(hf_quant_algo) if quant_config.quant_algo is None: logger.info( - f"Setting quant_algo={hf_quant_algo} form HF quant config." + f"Setting quant_algo={hf_quant_algo} from HF quant config." ) quant_config.quant_algo = hf_quant_algo elif quant_config.quant_algo != hf_quant_algo: @@ -372,7 +372,7 @@ def _update_from_hf_quant_config(self) -> bool: quant_config.kv_cache_quant_algo = explicit_kv_cache_quant_algo elif quant_config.kv_cache_quant_algo is None: logger.info( - f"Setting kv_cache_quant_algo={hf_kv_cache_quant_algo} form HF quant config." + f"Setting kv_cache_quant_algo={hf_kv_cache_quant_algo} from HF quant config." ) quant_config.kv_cache_quant_algo = hf_kv_cache_quant_algo elif quant_config.kv_cache_quant_algo != hf_kv_cache_quant_algo: @@ -536,7 +536,7 @@ def _load_model_from_hf(self): mapping=self.mapping, quant_config=self.llm_args.quant_config, load_model_on_cpu= - True, # TODO:TRTLLM-195 to enhance the weights loading memory usage and chose best location + True, # TODO:TRTLLM-195 to enhance the weights loading memory usage and choose best location trust_remote_code=self.llm_args.trust_remote_code, speculative_model_dir=self._speculative_model_dir, speculative_config=self.llm_args.speculative_config @@ -590,7 +590,7 @@ def _build_engine(self): copied_build_config = self.build_config.model_copy(deep=True) copied_build_config.update_kv_cache_type(self._model_info.architecture) - assert self.model is not None, "model is loaded yet." + assert self.model is not None, "model has not been loaded yet." 
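Reviewer note: several of the message fixes in this patch (the `apply_qk_norm` error, and the `garbage_collection_gen0_threshold`, `ray_worker_extension_cls`, and `enable_sleep` descriptions above) address the same underlying pitfall: Python implicitly concatenates adjacent string literals with no separator, so a missing trailing space glues words together across the line break. A minimal illustration, reusing one of the strings from this patch:

```python
# Adjacent string literals are joined verbatim by the parser, so a missing
# trailing space on the first literal runs the words together.
broken = ("The full worker extension class name including module path."
          "Allows users to extend the functions of the RayGPUWorker class.")
assert "path.Allows" in broken  # words run together

fixed = ("The full worker extension class name including module path. "
         "Allows users to extend the functions of the RayGPUWorker class.")
assert "path. Allows" in fixed
```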
self._engine = build(self.model, copied_build_config) self.mapping = self.model.config.mapping @@ -636,7 +636,7 @@ def load_hf_generation_config( model_dir, **kwargs) except Exception as e: logger.warning( - f"Failed to load hf generation config from {model_dir}, encounter error: {e}" + f"Failed to load hf generation config from {model_dir}, encountered error: {e}" ) return None @@ -650,14 +650,14 @@ def load_hf_model_config( model_dir, trust_remote_code=trust_remote_code, **kwargs) except Exception as e: logger.warning( - f"Failed to load hf model config from {model_dir}, encounter error: {e}" + f"Failed to load hf model config from {model_dir}, encountered error: {e}" ) return None class CachedModelLoader: ''' - The CacheModelLoader is used to build the model in both single or multi-gpu, with cache might be enabled. + The CachedModelLoader is used to build the model in both single- and multi-GPU runs, with optional caching. ''' def __init__( self, @@ -741,7 +741,7 @@ def __call__(self) -> Tuple[Path, Union[Path, None]]: if self.llm_args.quant_config.quant_algo is not None: logger.warning( - "QuantConfig for pytorch backend is ignored. You can load" + "QuantConfig for pytorch backend is ignored. You can load " "quantized model with hf_quant_config.json directly.") # Currently, this is to make updated quant_config visible by llm.args.quant_config # TODO: Unify the logics with those in tensorrt_llm/_torch/model_config.py diff --git a/tensorrt_llm/models/modeling_utils.py b/tensorrt_llm/models/modeling_utils.py index 812d67b0925..a85e959cd66 100644 --- a/tensorrt_llm/models/modeling_utils.py +++ b/tensorrt_llm/models/modeling_utils.py @@ -515,7 +515,7 @@ def to_layer_quant_config(self, config_file: str): config["quantized_layers"][moe_name] = quant_cfg else: assert quant_cfg == config["quantized_layers"][ - moe_name], "MoE module needs to have the same quantization format for non-rounter sub-modules" + moe_name], "MoE module needs to have the same quantization format for non-router sub-modules" self.quantization = LayerQuantConfig.model_validate(config) @@ -817,9 +817,9 @@ def prepare_inputs( mrope_rotary_cos_sin_size: int = None, ): '''@brief: Prepare inputs Tensors for the model, the given sizes are used to determine the - ranges of the dimensions of when using TRT dynamic shapes. + ranges of the dimensions when using TRT dynamic shapes.
- @return: a list contains values which can be fed into the self.forward() + @return: a list containing values which can be fed into self.forward() ''' # Prepare inputs @@ -1545,7 +1545,7 @@ def share_embedding(model: PretrainedModel) -> PretrainedModel: return model -def set_fp8_context_fhma(model: PretrainedModel) -> PretrainedModel: +def set_fp8_context_fmha(model: PretrainedModel) -> PretrainedModel: for name, layer in model.named_modules(): if isinstance(layer, Attention) and hasattr( layer.dense, 'activation_scaling_factor'): @@ -1603,7 +1603,7 @@ def optimize_model( model = parallelize_embedding(model) if share_embedding_table: - # if share_embedding_table is enabled, only one copy of the embedding table is store in converted ckpt + # if share_embedding_table is enabled, only one copy of the embedding table is stored in converted ckpt # this pass is required to make lm_head.weight and vocab_embedding.weight point to the same tensor # however even if share_embedding_table is not enabled, trt would still only keep one copy of the table if the weights are identical model = share_embedding(model) @@ -1623,7 +1623,7 @@ def optimize_model( if use_lora: model = add_lora(model, max_lora_rank, with_dora=use_dora) if use_fp8_context_fmha: - model = set_fp8_context_fhma(model) + model = set_fp8_context_fmha(model) if fuse_fp4_quant: model = set_fuse_fp4_quant(model) if not use_lora and use_optimize_cross_qkv is True: @@ -1639,7 +1639,7 @@ def optimize_cross_qkv(model): So, add a new attribute 'kv' in the cross_attention layer. This might lead to additional memory cost on model size, but save the memory usage on runtime. - Currently, this function only detect the ColumnLinear and FP8Linear. It does not supports + Currently, this function only detects ColumnLinear and FP8Linear. It does not support other quantization now. """ for name, attn, layer in model.named_modules_with_parent(): diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index f84c0534914..f52fe64afe2 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -88,21 +88,22 @@ from .harmony_adapter import (HarmonyAdapter, get_harmony_adapter, maybe_transform_reasoning_effort) -# yapf: enale +# yapf: enable TIMEOUT_KEEP_ALIVE = 5 # seconds.
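Reviewer note: the `# yapf: enale` → `# yapf: enable` correction above is likely what produced the large, behavior-neutral reformatting hunks that follow in this file: a misspelled directive is just an ordinary comment, so formatting stayed disabled from that point onward, and correcting it lets the formatter reflow the rest of the module. An illustrative sketch of the directive pair, assuming yapf's standard `disable`/`enable` comment directives:

```python
# yapf: disable
IDENTITY = [1, 0,
            0, 1]  # hand-aligned; yapf leaves this region as-is
# yapf: enable
# Formatting resumes here. A misspelled directive such as "# yapf: enale"
# is treated as an ordinary comment, so the disabled region never ends.
```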
class OpenAIServer: - def __init__(self, - generator: Union[LLM, MultimodalEncoder, VisualGen], - model: str, - tool_parser: Optional[str], - server_role: Optional[ServerRole], - metadata_server_cfg: MetadataServerConfig, - disagg_cluster_config: Optional[DisaggClusterConfig] = None, - multimodal_server_config: Optional[MultimodalServerConfig] = None, - chat_template: Optional[str] = None): + def __init__( + self, + generator: Union[LLM, MultimodalEncoder, VisualGen], + model: str, + tool_parser: Optional[str], + server_role: Optional[ServerRole], + metadata_server_cfg: MetadataServerConfig, + disagg_cluster_config: Optional[DisaggClusterConfig] = None, + multimodal_server_config: Optional[MultimodalServerConfig] = None, + chat_template: Optional[str] = None): self.generator = generator self._is_visual_gen = isinstance(generator, VisualGen) self.tool_parser = tool_parser @@ -139,7 +140,6 @@ def __init__(self, else: self._init_llm(chat_template) - @asynccontextmanager async def lifespan(app: FastAPI): if self.metadata_server is not None: @@ -152,19 +152,26 @@ async def lifespan(app: FastAPI): } # TODO: add more metadata # Register with ETCD using the existing key format - self.metadata_server.put(f"trtllm/{self.generator.llm_id}", metadata) + self.metadata_server.put(f"trtllm/{self.generator.llm_id}", + metadata) logger.info(f"trtllm/{self.generator.llm_id} is registered") if self.disagg_cluster_config: - self.disagg_cluster_storage = create_cluster_storage_client(self.disagg_cluster_config.cluster_uri, self.disagg_cluster_config.cluster_name) - self.disagg_cluster_worker= DisaggClusterWorker(self.server_role, self.host, self.port, self.disagg_cluster_config, self.disagg_cluster_storage) + self.disagg_cluster_storage = create_cluster_storage_client( + self.disagg_cluster_config.cluster_uri, + self.disagg_cluster_config.cluster_name) + self.disagg_cluster_worker = DisaggClusterWorker( + self.server_role, self.host, self.port, + self.disagg_cluster_config, self.disagg_cluster_storage) await self.disagg_cluster_worker.register_worker() # Start background iteration stats collector if metrics are enabled # The args for pytorch and autodeploy backend has attribute `enable_iter_perf_stats` while # tensorrt backend does not have this attribute but it always has iter stats enabled. 
- if self.metrics_collector and getattr(self.generator.args, "enable_iter_perf_stats", True): - self._iteration_stats_collector_task = asyncio.create_task(self._iteration_stats_collector_loop()) + if self.metrics_collector and getattr( + self.generator.args, "enable_iter_perf_stats", True): + self._iteration_stats_collector_task = asyncio.create_task( + self._iteration_stats_collector_loop()) logger.info("Started background iteration stats collector task") # terminate rank0 worker @@ -193,42 +200,50 @@ async def validation_exception_handler(_, exc): return JSONResponse(status_code=400, content={"error": str(exc)}) if self.server_role is ServerRole.VISUAL_GEN: - assert isinstance(self.generator, VisualGen), "generator must be a VisualGen for VISUAL_GEN server" + assert isinstance( + self.generator, VisualGen + ), "generator must be a VisualGen for VISUAL_GEN server" self.register_visual_gen_routes() elif self.server_role is ServerRole.MM_ENCODER: - assert isinstance(self.generator, MultimodalEncoder), "generator must be a MultimodalEncoder for multimodal encoder" + assert isinstance( + self.generator, MultimodalEncoder + ), "generator must be a MultimodalEncoder for multimodal encoder" self.register_mm_encoder_routes() else: self.register_routes() self.app.add_middleware(ServerArrivalTimeMiddleware) - def _init_visual_gen(self): self.processor = None self.model_config = None - self.media_storage_path = Path(os.getenv("TRTLLM_MEDIA_STORAGE_PATH", "/tmp/trtllm_generated")) # nosec B108 - self.media_storage_path.mkdir(exist_ok=True, parents= True) + self.media_storage_path = Path( + os.getenv("TRTLLM_MEDIA_STORAGE_PATH", + "/tmp/trtllm_generated")) # nosec B108 + self.media_storage_path.mkdir(exist_ok=True, parents=True) self.video_gen_tasks = {} - def _init_llm(self, chat_template: Optional[str] = None): self.tokenizer = self.generator.tokenizer hf_tokenizer_path = self.generator._hf_model_dir or self.tokenizer.tokenizer.name_or_path trust_remote_code = self.generator.args.trust_remote_code try: - self.processor = AutoProcessor.from_pretrained(hf_tokenizer_path, trust_remote_code=trust_remote_code) + self.processor = AutoProcessor.from_pretrained( + hf_tokenizer_path, trust_remote_code=trust_remote_code) except Exception: - logger.debug("Failed to load AutoProcessor or AutoConfig for %s", hf_tokenizer_path) + logger.debug("Failed to load AutoProcessor or AutoConfig for %s", + hf_tokenizer_path) self.processor = None # load model config try: from tensorrt_llm._torch.pyexecutor.config_utils import \ load_pretrained_config - self.model_config = load_pretrained_config(hf_tokenizer_path, - trust_remote_code=trust_remote_code, - checkpoint_format=getattr(self.generator.args, "checkpoint_format", None)) + self.model_config = load_pretrained_config( + hf_tokenizer_path, + trust_remote_code=trust_remote_code, + checkpoint_format=getattr(self.generator.args, + "checkpoint_format", None)) except Exception: logger.debug("Failed to load AutoConfig for %s", hf_tokenizer_path) self.model_config = None @@ -236,7 +251,9 @@ def _init_llm(self, chat_template: Optional[str] = None): self.chat_template = load_chat_template(chat_template) # Enable response storage for Responses API - self.enable_store = (len(os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) < 1) and not self.postproc_worker_enabled + self.enable_store = (len( + os.getenv("TRTLLM_RESPONSES_API_DISABLE_STORE", "")) + < 1) and not self.postproc_worker_enabled self.conversation_store = ConversationHistoryStore() @@ -248,7 +265,7 @@ def 
_init_llm(self, chat_template: Optional[str] = None): else: self.use_harmony = (self.model_config.model_type == "gpt_oss") - self.tool_call_id_type = "random" # default tool call id type is random + self.tool_call_id_type = "random" # default tool call id type is random if self.model_config is not None: if self.model_config.model_type == "kimi_k2": self.tool_call_id_type = "kimi_k2" @@ -266,7 +283,6 @@ def _init_llm(self, chat_template: Optional[str] = None): self.perf_metrics = deque(maxlen=max_perf_metrics) self.perf_metrics_lock = asyncio.Lock() - async def await_disconnected(self, raw_request: Request, promise): if raw_request is None: return @@ -275,7 +291,8 @@ async def await_disconnected(self, raw_request: Request, promise): if not promise.finished: promise.abort() logger.info( - f"{raw_request.client} is disconnected, abort {promise.request_id}") + f"{raw_request.client} is disconnected, abort {promise.request_id}" + ) @property def postproc_worker_enabled(self) -> bool: @@ -323,23 +340,36 @@ def _check_health(self) -> bool: def register_routes(self): self.app.add_api_route("/health", self.health, methods=["GET"]) - self.app.add_api_route("/health_generate", self.health_generate, methods=["GET"]) + self.app.add_api_route("/health_generate", + self.health_generate, + methods=["GET"]) self.app.add_api_route("/version", self.version, methods=["GET"]) self.app.add_api_route("/v1/models", self.get_model, methods=["GET"]) # TODO: the metrics endpoint only reports iteration stats, not the runtime stats for now - self.app.add_api_route("/metrics", self.get_iteration_stats, methods=["GET"]) - self.app.add_api_route("/perf_metrics", self.get_perf_metrics, methods=["GET"]) - self.app.add_api_route("/steady_clock_offset", self.get_steady_clock_offset, methods=["GET"]) + self.app.add_api_route("/metrics", + self.get_iteration_stats, + methods=["GET"]) + self.app.add_api_route("/perf_metrics", + self.get_perf_metrics, + methods=["GET"]) + self.app.add_api_route("/steady_clock_offset", + self.get_steady_clock_offset, + methods=["GET"]) # Called by the disagg server to set the disagg_server_steady_clock_offset - self.app.add_api_route("/steady_clock_offset", self.set_steady_clock_offset, methods=["POST"]) + self.app.add_api_route("/steady_clock_offset", + self.set_steady_clock_offset, + methods=["POST"]) # TODO: workaround before ETCD support - self.app.add_api_route("/kv_cache_events", self.get_kv_cache_events, methods=["POST"]) + self.app.add_api_route("/kv_cache_events", + self.get_kv_cache_events, + methods=["POST"]) self.app.add_api_route("/v1/completions", self.openai_completion, methods=["POST"]) - self.app.add_api_route("/v1/chat/completions", - self.openai_chat if not self.use_harmony else self.chat_harmony, - methods=["POST"]) + self.app.add_api_route( + "/v1/chat/completions", + self.openai_chat if not self.use_harmony else self.chat_harmony, + methods=["POST"]) self.app.add_api_route("/v1/responses", self.openai_responses, methods=["POST"]) @@ -352,17 +382,17 @@ def register_routes(self): # RL-only endpoints self.app.add_api_route("/release_memory", - self.release_memory, - methods=["POST"]) + self.release_memory, + methods=["POST"]) self.app.add_api_route("/resume_memory", - self.resume_memory, - methods=["POST"]) + self.resume_memory, + methods=["POST"]) self.app.add_api_route("/update_weights", - self.update_weights, - methods=["POST"]) + self.update_weights, + methods=["POST"]) self.app.add_api_route("/server_info", - self.get_server_info, - methods=["GET"]) + self.get_server_info, 
+ methods=["GET"]) if self.generator.args.return_perf_metrics: # register /prometheus/metrics self.mount_metrics() @@ -380,14 +410,13 @@ def mount_metrics(self): Instrumentator( should_group_status_codes=False, should_respect_env_var=True, - excluded_handlers=[ - ".*" - ], + excluded_handlers=[".*"], registry=registry, ).add().instrument(self.app).expose(self.app) metrics_app = make_asgi_app(registry=registry) metrics_route = Mount("/prometheus/metrics", metrics_app) - metrics_route.path_regex = re.compile("^/prometheus/metrics(?P.*)$") + metrics_route.path_regex = re.compile( + "^/prometheus/metrics(?P.*)$") self.app.routes.append(metrics_route) def register_mm_encoder_routes(self): @@ -395,20 +424,22 @@ def register_mm_encoder_routes(self): self.app.add_api_route("/version", self.version, methods=["GET"]) self.app.add_api_route("/v1/models", self.get_model, methods=["GET"]) # TODO: the metrics endpoint only reports iteration stats, not the runtime stats for now - self.app.add_api_route("/metrics", self.get_iteration_stats, methods=["GET"]) + self.app.add_api_route("/metrics", + self.get_iteration_stats, + methods=["GET"]) self.app.add_api_route("/v1/chat/completions", self.openai_mm_encoder, methods=["POST"]) # RL-only endpoints self.app.add_api_route("/release_memory", - self.release_memory, - methods=["POST"]) + self.release_memory, + methods=["POST"]) self.app.add_api_route("/resume_memory", - self.resume_memory, - methods=["POST"]) + self.resume_memory, + methods=["POST"]) self.app.add_api_route("/update_weights", - self.update_weights, - methods=["POST"]) + self.update_weights, + methods=["POST"]) def register_visual_gen_routes(self): """Register routes for diffusion model serving.""" @@ -416,7 +447,9 @@ def register_visual_gen_routes(self): self.app.add_api_route("/health", self.health, methods=["GET"]) self.app.add_api_route("/version", self.version, methods=["GET"]) self.app.add_api_route("/v1/models", self.get_model, methods=["GET"]) - self.app.add_api_route("/metrics", self.get_iteration_stats, methods=["GET"]) + self.app.add_api_route("/metrics", + self.get_iteration_stats, + methods=["GET"]) # Image generation endpoints (OpenAI compatible) self.app.add_api_route("/v1/images/generations", @@ -436,9 +469,7 @@ def register_visual_gen_routes(self): self.openai_video_generation_sync, methods=["POST"]) # Video management endpoints - self.app.add_api_route("/v1/videos", - self.list_videos, - methods=["GET"]) + self.app.add_api_route("/v1/videos", self.list_videos, methods=["GET"]) self.app.add_api_route("/v1/videos/{video_id}", self.get_video_metadata, methods=["GET"]) @@ -453,7 +484,11 @@ async def health(self) -> Response: if self._check_health(): return Response(status_code=200) else: - return Response(status_code=503, content="LLM is unavailable. Please check the server logs for more details.") + return Response( + status_code=503, + content= + "LLM is unavailable. Please check the server logs for more details." 
+ ) async def health_generate(self, raw_request: Request) -> Response: """Health check that performs a minimal generation.""" @@ -467,11 +502,14 @@ async def health_generate(self, raw_request: Request) -> Response: try: # Create a minimal chat request health_request = ChatCompletionRequest( - messages=[{"role": "user", "content": "hi"}], # Minimal prompt (often > 1 token after tokenization) + messages=[{ + "role": "user", + "content": "hi" + }], # Minimal prompt (often > 1 token after tokenization) model=self.model, - max_completion_tokens=1, # Request only 1 token out + max_completion_tokens=1, # Request only 1 token out stream=False, - temperature=0.0, # Deterministic output + temperature=0.0, # Deterministic output **extra_args, ) @@ -480,20 +518,26 @@ async def health_generate(self, raw_request: Request) -> Response: # Check if the response indicates success (status code 200) if response.status_code == 200: - return Response(status_code=200, content="Generation health check OK") + return Response(status_code=200, + content="Generation health check OK") else: - logger.error(f"Health generate check failed with status code: {response.status_code}") + logger.error( + f"Health generate check failed with status code: {response.status_code}" + ) try: # Attempt to get body for more details if possible - body = response.body if hasattr(response, 'body') else await response.body() + body = response.body if hasattr( + response, 'body') else await response.body() logger.error(f"Health generate check response body: {body}") except Exception: - pass # Ignore errors trying to get body details + pass # Ignore errors trying to get body details - return Response(status_code=500, content="Generation health check failed") + return Response(status_code=500, + content="Generation health check failed") except Exception as e: logger.error(f"Health generate check encountered exception: {e}") - return Response(status_code=500, content=f"Generation health check failed: {str(e)}") + return Response(status_code=500, + content=f"Generation health check failed: {str(e)}") async def version(self) -> JSONResponse: ver = {"version": VERSION} @@ -509,23 +553,30 @@ async def get_iteration_stats(self) -> JSONResponse: stats.append(stat) return JSONResponse(content=stats) - async def set_steady_clock_offset(self, offset: Annotated[float, Body(embed=True)]) -> Response: + async def set_steady_clock_offset( + self, offset: Annotated[float, Body(embed=True)]) -> Response: self.disagg_server_steady_clock_offset = offset - logger.info(f"The steady clock offset between local and disagg server: {offset} second") + logger.info( + f"The steady clock offset between local and disagg server: {offset} seconds" + ) return Response(status_code=200) async def get_steady_clock_offset(self) -> JSONResponse: receive_ts = get_steady_clock_now_in_seconds() await asyncio.sleep(0.2) transmit_ts = get_steady_clock_now_in_seconds() - return JSONResponse(content={"receive_ts": receive_ts, "transmit_ts": transmit_ts}) + return JSONResponse(content={ + "receive_ts": receive_ts, + "transmit_ts": transmit_ts + }) async def get_perf_metrics(self) -> JSONResponse: if self.perf_metrics is None: return JSONResponse(content=[]) async with self.perf_metrics_lock: perf_metrics = self.perf_metrics - self.perf_metrics = deque(maxlen=self.generator.args.perf_metrics_max_requests) + self.perf_metrics = deque( + maxlen=self.generator.args.perf_metrics_max_requests) for metrics_dict in perf_metrics: metrics = metrics_dict["perf_metrics"] timing_metrics =
metrics.timing_metrics @@ -539,35 +590,55 @@ async def get_perf_metrics(self) -> JSONResponse: server_arrival_time = metrics_dict.pop("server_arrival_time", None) if server_arrival_time is not None: server_arrival_time += self.disagg_server_steady_clock_offset - server_first_token_time = metrics_dict.pop("server_first_token_time", None) + server_first_token_time = metrics_dict.pop( + "server_first_token_time", None) if server_first_token_time is not None: server_first_token_time += self.disagg_server_steady_clock_offset metrics_json["timing_metrics"] = { - "server_arrival_time": server_arrival_time, - "arrival_time": timing_metrics.arrival_time.total_seconds() + self.disagg_server_steady_clock_offset, - "first_scheduled_time": timing_metrics.first_scheduled_time.total_seconds() + self.disagg_server_steady_clock_offset, - "first_token_time": timing_metrics.first_token_time.total_seconds() + self.disagg_server_steady_clock_offset, - "server_first_token_time": server_first_token_time, - "last_token_time": timing_metrics.last_token_time.total_seconds() + self.disagg_server_steady_clock_offset, + "server_arrival_time": + server_arrival_time, + "arrival_time": + timing_metrics.arrival_time.total_seconds() + + self.disagg_server_steady_clock_offset, + "first_scheduled_time": + timing_metrics.first_scheduled_time.total_seconds() + + self.disagg_server_steady_clock_offset, + "first_token_time": + timing_metrics.first_token_time.total_seconds() + + self.disagg_server_steady_clock_offset, + "server_first_token_time": + server_first_token_time, + "last_token_time": + timing_metrics.last_token_time.total_seconds() + + self.disagg_server_steady_clock_offset, } metrics_json["kv_cache_metrics"] = { - "num_total_allocated_blocks": kv_cache_metrics.num_total_allocated_blocks, - "num_new_allocated_blocks": kv_cache_metrics.num_new_allocated_blocks, + "num_total_allocated_blocks": + kv_cache_metrics.num_total_allocated_blocks, + "num_new_allocated_blocks": + kv_cache_metrics.num_new_allocated_blocks, "num_reused_blocks": kv_cache_metrics.num_reused_blocks, "num_missed_blocks": kv_cache_metrics.num_missed_blocks, } if timing_metrics.kv_cache_size > 0: metrics_json["timing_metrics"].update({ # TODO: move to kv_cache_metrics - "kv_cache_size": timing_metrics.kv_cache_size, - "kv_cache_transfer_start": timing_metrics.kv_cache_transfer_start.total_seconds() + self.disagg_server_steady_clock_offset, - "kv_cache_transfer_end": timing_metrics.kv_cache_transfer_end.total_seconds() + self.disagg_server_steady_clock_offset, + "kv_cache_size": + timing_metrics.kv_cache_size, + "kv_cache_transfer_start": + timing_metrics.kv_cache_transfer_start.total_seconds() + + self.disagg_server_steady_clock_offset, + "kv_cache_transfer_end": + timing_metrics.kv_cache_transfer_end.total_seconds() + + self.disagg_server_steady_clock_offset, }) if speculative_decoding.total_draft_tokens > 0: metrics_json["speculative_decoding"] = { "acceptance_rate": speculative_decoding.acceptance_rate, - "total_accepted_draft_tokens": speculative_decoding.total_accepted_draft_tokens, - "total_draft_tokens": speculative_decoding.total_draft_tokens, + "total_accepted_draft_tokens": + speculative_decoding.total_accepted_draft_tokens, + "total_draft_tokens": + speculative_decoding.total_draft_tokens, } metrics_dict["perf_metrics"] = metrics_json return JSONResponse(content=list(perf_metrics)) @@ -598,12 +669,18 @@ async def _extract_metrics(self, res: RequestOutput, raw_request: Request): "perf_metrics": res.outputs[0].request_perf_metrics } if 
raw_request: - item["server_arrival_time"] = getattr(raw_request.state, "server_arrival_time", None) - if not getattr(raw_request.state, "server_first_token_time", None): - raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() - item["server_first_token_time"] = raw_request.state.server_first_token_time + item["server_arrival_time"] = getattr(raw_request.state, + "server_arrival_time", + None) + if not getattr(raw_request.state, "server_first_token_time", + None): + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds( + ) + item[ + "server_first_token_time"] = raw_request.state.server_first_token_time if output.disaggregated_params: - item["ctx_request_id"] = output.disaggregated_params.ctx_request_id + item[ + "ctx_request_id"] = output.disaggregated_params.ctx_request_id # Request-level time breakdown (on GenerationResult/RequestOutput, not CompletionOutput) if getattr(res, 'time_breakdown_metrics', None) is not None: item["time_breakdown_metrics"] = res.time_breakdown_metrics @@ -611,11 +688,13 @@ async def _extract_metrics(self, res: RequestOutput, raw_request: Request): async with self.perf_metrics_lock: self.perf_metrics.append(item) - async def _create_chat_response(self, - promise: RequestOutput, - postproc_params: PostprocParams, - raw_request: Request, - disaggregated_params: Optional[LlmDisaggregatedParams] = None) -> ChatCompletionResponse: + async def _create_chat_response( + self, + promise: RequestOutput, + postproc_params: PostprocParams, + raw_request: Request, + disaggregated_params: Optional[LlmDisaggregatedParams] = None + ) -> ChatCompletionResponse: await promise.aresult() if self.postproc_worker_enabled: chat_response = promise.outputs[0]._postprocess_result @@ -623,9 +702,11 @@ async def _create_chat_response(self, post_processor, args = postproc_params.post_processor, postproc_params.postproc_args chat_response = post_processor(promise, args) - if disaggregated_params is not None and chat_response.choices[0].disaggregated_params is None: - raise ValueError(f"disaggregated_params is not set in the response for request" - f" {disaggregated_params.disagg_request_id}") + if disaggregated_params is not None and chat_response.choices[ + 0].disaggregated_params is None: + raise ValueError( + f"disaggregated_params is not set in the response for request" + f" {disaggregated_params.disagg_request_id}") await self._extract_metrics(promise, raw_request) return chat_response @@ -658,7 +739,8 @@ async def _iteration_stats_collector_loop(self): # Since metrics are gauges (point-in-time values), only the most recent stat matters try: latest_stat = None - async for llm_stat in self.generator.get_stats_async(timeout=0.5): + async for llm_stat in self.generator.get_stats_async( + timeout=0.5): latest_stat = llm_stat # Keep only the latest # Log only the most recent iteration stats to Prometheus @@ -666,14 +748,16 @@ async def _iteration_stats_collector_loop(self): self.metrics_collector.log_iteration_stats(latest_stat) except Exception as e: # Log errors but continue collecting stats - logger.error(f"Error collecting iteration stats: {e}", exc_info=True) + logger.error(f"Error collecting iteration stats: {e}", + exc_info=True) # Brief sleep to avoid tight loop on persistent errors await asyncio.sleep(0.1) except asyncio.CancelledError: logger.info("Iteration stats collector loop cancelled") raise - async def openai_chat(self, request: ChatCompletionRequest, raw_request: Request) -> Response: + async def openai_chat(self, request: 
ChatCompletionRequest, + raw_request: Request) -> Response: def get_role() -> str: if request.add_generation_prompt: @@ -683,19 +767,25 @@ def get_role() -> str: return role async def chat_stream_generator( - promise: RequestOutput, postproc_params: PostprocParams) -> AsyncGenerator[str, None]: + promise: RequestOutput, + postproc_params: PostprocParams) -> AsyncGenerator[str, None]: try: if not self.postproc_worker_enabled: post_processor, args = postproc_params.post_processor, postproc_params.postproc_args first_response = await anext(promise) - raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() - pp_results = first_response.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(first_response, args) + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds( + ) + pp_results = first_response.outputs[ + 0]._postprocess_result if self.postproc_worker_enabled else post_processor( + first_response, args) for pp_res in pp_results: yield pp_res - # Making sure we can handling the situation where there is only one response + # Making sure we can handle the situation where there is only one response res = first_response async for res in promise: - pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) + pp_results = res.outputs[ + 0]._postprocess_result if self.postproc_worker_enabled else post_processor( + res, args) for pp_res in pp_results: yield pp_res yield "data: [DONE]\n\n" @@ -714,13 +804,17 @@ async def chat_stream_generator( # expanded into an embedding bias tensor in the sampler. sampling_params = request.to_sampling_params( vocab_size=self.tokenizer.tokenizer.vocab_size, - gather_generation_logits=self.generator.args.gather_generation_logits, + gather_generation_logits=self.generator.args. + gather_generation_logits, reasoning_parser=self.generator.args.reasoning_parser, backend=self.generator.args.backend) postproc_args = ChatPostprocArgs.from_request(request) - disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params) + disaggregated_params = to_llm_disaggregated_params( + request.disaggregated_params) - conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines(request.messages, self.model_config, self.multimodal_server_config) + conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines( + request.messages, self.model_config, + self.multimodal_server_config) if request.prompt_token_ids is not None: prompt = request.prompt_token_ids @@ -745,7 +839,9 @@ async def chat_stream_generator( if mm_embeddings: prompt["multi_modal_embeddings"] = mm_embeddings if mm_data and mm_embeddings: - raise ValueError("Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported.") + raise ValueError( + "Passing 'multi_modal_data' and 'multi_modal_embeddings' at the same time is not supported." 
+ ) postproc_args.reasoning_parser = self.generator.args.reasoning_parser postproc_args.tool_parser = self.tool_parser @@ -759,7 +855,8 @@ async def chat_stream_generator( postproc_args=postproc_args, ) - trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers)) + trace_headers = (None if raw_request is None else + tracing.extract_trace_headers(raw_request.headers)) generate_inputs = prompt preprocess_fn = getattr(self.generator, "preprocess", None) @@ -771,7 +868,8 @@ async def chat_stream_generator( promise = self.generator.generate_async( inputs=generate_inputs, sampling_params=sampling_params, - _postproc_params=postproc_params if self.postproc_worker_enabled else None, + _postproc_params=postproc_params + if self.postproc_worker_enabled else None, streaming=request.stream, lora_request=request.lora_request, disaggregated_params=disaggregated_params, @@ -784,11 +882,13 @@ async def chat_stream_generator( postproc_args.num_prompt_tokens = len(promise.prompt_token_ids) if request.stream: - response_generator = chat_stream_generator(promise, postproc_params) + response_generator = chat_stream_generator( + promise, postproc_params) return StreamingResponse(content=response_generator, media_type="text/event-stream") else: - response = await self._create_chat_response(promise, postproc_params, raw_request, disaggregated_params) + response = await self._create_chat_response( + promise, postproc_params, raw_request, disaggregated_params) return JSONResponse(content=response.model_dump()) except CppExecutorError: logger.error(traceback.format_exc()) @@ -798,16 +898,15 @@ async def chat_stream_generator( logger.error(traceback.format_exc()) return self.create_error_response(str(e)) - async def openai_mm_encoder(self, request: ChatCompletionRequest, raw_request: Request) -> Response: + async def openai_mm_encoder(self, request: ChatCompletionRequest, + raw_request: Request) -> Response: async def create_mm_embedding_response(promise: RequestOutput): await promise.aresult() # TODO: Replace mm_embedding_handles with a dedicated OpenAIBaseModel(JSON-safe), when enable multimodal disagg E2E mm_embedding_handles = ( promise.disaggregated_params.multimodal_embedding_handles - if promise.disaggregated_params - else None - ) + if promise.disaggregated_params else None) if not mm_embedding_handles: return self.create_error_response( message="Multimodal embedding handle missing in response", @@ -818,12 +917,11 @@ async def create_mm_embedding_response(promise: RequestOutput): message="Multimodal embedding handle missing tensor_size", err_type="InternalServerError", status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - mm_embedding_handle = ( - mm_embedding_handles[0] - if len(mm_embedding_handles) == 1 - else mm_embedding_handles - ) - num_tokens = sum(int(h["tensor_size"][0]) for h in mm_embedding_handles) + mm_embedding_handle = (mm_embedding_handles[0] + if len(mm_embedding_handles) == 1 else + mm_embedding_handles) + num_tokens = sum( + int(h["tensor_size"][0]) for h in mm_embedding_handles) return ChatCompletionResponse( id=str(promise.request_id), model=self.model, @@ -848,7 +946,8 @@ async def create_mm_embedding_response(promise: RequestOutput): tool.model_dump() for tool in request.tools ] - conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines(request.messages, self.model_config) + conversation, mm_coroutines, mm_placeholder_counts = parse_chat_messages_coroutines( + request.messages, self.model_config) if 
request.prompt_token_ids is not None: prompt = request.prompt_token_ids @@ -873,9 +972,7 @@ async def create_mm_embedding_response(promise: RequestOutput): if mm_data is not None: prompt["multi_modal_data"] = mm_data - promise = self.generator.generate_async( - inputs=prompt, - ) + promise = self.generator.generate_async(inputs=prompt, ) asyncio.create_task(self.await_disconnected(raw_request, promise)) response = await create_mm_embedding_response(promise) @@ -889,10 +986,13 @@ async def create_mm_embedding_response(promise: RequestOutput): logger.error(traceback.format_exc()) return self.create_error_response(str(e)) - async def openai_completion(self, request: CompletionRequest, raw_request: Request) -> Response: + async def openai_completion(self, request: CompletionRequest, + raw_request: Request) -> Response: - async def completion_response(promise: RequestOutput, - postproc_params: Optional[PostprocParams]) -> CompletionResponse: + async def completion_response( + promise: RequestOutput, + postproc_params: Optional[PostprocParams] + ) -> CompletionResponse: response = await promise if not self.postproc_worker_enabled: post_processor, args = postproc_params.post_processor, postproc_params.postproc_args @@ -905,7 +1005,8 @@ async def completion_response(promise: RequestOutput, await self._extract_metrics(response, raw_request) return pp_result - def merge_completion_responses(responses: List[CompletionResponse]) -> CompletionResponse: + def merge_completion_responses( + responses: List[CompletionResponse]) -> CompletionResponse: all_choices: List[CompletionResponseChoice] = [] all_prompt_token_ids: List[List[int]] = [] num_prompt_tokens = num_gen_tokens = num_cached_tokens = 0 @@ -924,8 +1025,7 @@ def merge_completion_responses(responses: List[CompletionResponse]) -> Completio completion_tokens=num_gen_tokens, total_tokens=num_gen_tokens + num_prompt_tokens, prompt_tokens_details=PromptTokensDetails( - cached_tokens=num_cached_tokens, - ), + cached_tokens=num_cached_tokens, ), ) merged_rsp = CompletionResponse( model=self.model, @@ -935,7 +1035,8 @@ def merge_completion_responses(responses: List[CompletionResponse]) -> Completio ) return merged_rsp - async def completion_generator(promise: RequestOutput, params: Optional[PostprocParams]): + async def completion_generator(promise: RequestOutput, + params: Optional[PostprocParams]): try: async for output in promise: if not self.postproc_worker_enabled: @@ -950,7 +1051,6 @@ async def completion_generator(promise: RequestOutput, params: Optional[Postproc logger.error(traceback.format_exc()) raise - async def merge_generators(generators: List[AsyncIterator[Any]]): result_queue = asyncio.Queue() finished = [False] * len(generators) @@ -961,7 +1061,8 @@ async def producer(generator: AsyncIterator[Any], idx: int): finished[idx] = True tasks = [ - asyncio.create_task(producer(generator, idx)) for idx, generator in enumerate(generators) + asyncio.create_task(producer(generator, idx)) + for idx, generator in enumerate(generators) ] while not all(finished) or not result_queue.empty(): @@ -971,7 +1072,8 @@ async def producer(generator: AsyncIterator[Any], idx: int): async def generator_wrapper(generator: AsyncIterator[Any]): first_response = await anext(generator) - raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds() + raw_request.state.server_first_token_time = get_steady_clock_now_in_seconds( + ) yield first_response async for output in generator: yield output @@ -990,12 +1092,14 @@ async def 
generator_wrapper(generator: AsyncIterator[Any]): # expanded into an embedding bias tensor in the sampler. sampling_params = request.to_sampling_params( vocab_size=self.tokenizer.tokenizer.vocab_size, - gather_generation_logits=self.generator.args.gather_generation_logits, + gather_generation_logits=self.generator.args. + gather_generation_logits, backend=self.generator.args.backend) # TODO: better way to enable metrics if len(os.getenv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH", "")) > 0: sampling_params.return_perf_metrics = True - disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params) + disaggregated_params = to_llm_disaggregated_params( + request.disaggregated_params) for idx, prompt in enumerate(prompts): postproc_args = CompletionPostprocArgs.from_request(request) postproc_args.prompt_idx = idx @@ -1006,12 +1110,19 @@ async def generator_wrapper(generator: AsyncIterator[Any]): if request.stream else completion_response_post_processor, postproc_args=postproc_args, ) - trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers)) + trace_headers = (None if raw_request is None else + tracing.extract_trace_headers( + raw_request.headers)) prompt = prompt_inputs(prompt) if prompt.get("prompt") is not None: - prompt_token_ids, extra_processed_inputs = await asyncio.to_thread(self.generator.input_processor, prompt, sampling_params) - tokens_prompt = TokensPrompt(prompt_token_ids=prompt_token_ids, query_token_ids=extra_processed_inputs.get("query_token_ids") if extra_processed_inputs is not None else None) + prompt_token_ids, extra_processed_inputs = await asyncio.to_thread( + self.generator.input_processor, prompt, sampling_params) + tokens_prompt = TokensPrompt( + prompt_token_ids=prompt_token_ids, + query_token_ids=extra_processed_inputs.get( + "query_token_ids") + if extra_processed_inputs is not None else None) else: tokens_prompt = prompt @@ -1022,25 +1133,34 @@ async def generator_wrapper(generator: AsyncIterator[Any]): streaming=request.stream, lora_request=request.lora_request, disaggregated_params=disaggregated_params, - trace_headers=trace_headers - ) - asyncio.create_task(self.await_disconnected(raw_request, promise)) + trace_headers=trace_headers) + asyncio.create_task( + self.await_disconnected(raw_request, promise)) if not self.postproc_worker_enabled: postproc_args.tokenizer = self.tokenizer - postproc_args.num_prompt_tokens = len(promise.prompt_token_ids) + postproc_args.num_prompt_tokens = len( + promise.prompt_token_ids) promises.append(promise) - postproc_params_collection.append(None if self.postproc_worker_enabled else postproc_params) + postproc_params_collection.append( + None if self.postproc_worker_enabled else postproc_params) if request.stream: - generators = [completion_generator(promise, params) - for promise, params in zip(promises, postproc_params_collection)] - response_generator = merge_generators(generators) if len(promises) > 1 else generators[0] - return StreamingResponse(content=generator_wrapper(response_generator), - media_type="text/event-stream") + generators = [ + completion_generator(promise, params) for promise, params in + zip(promises, postproc_params_collection) + ] + response_generator = merge_generators(generators) if len( + promises) > 1 else generators[0] + return StreamingResponse( + content=generator_wrapper(response_generator), + media_type="text/event-stream") else: - rsps = await asyncio.gather(*[completion_response(promise, params) - for promise, params in zip(promises, 
postproc_params_collection)]) - response = merge_completion_responses(rsps) if len(rsps) > 1 else rsps[0] + rsps = await asyncio.gather(*[ + completion_response(promise, params) for promise, params in + zip(promises, postproc_params_collection) + ]) + response = merge_completion_responses(rsps) if len( + rsps) > 1 else rsps[0] return JSONResponse(content=response.model_dump()) except CppExecutorError: logger.error(traceback.format_exc()) @@ -1050,13 +1170,15 @@ async def generator_wrapper(generator: AsyncIterator[Any]): logger.error(traceback.format_exc()) return self.create_error_response(str(e)) - async def chat_harmony(self, request: ChatCompletionRequest, raw_request: Request) -> Response: + async def chat_harmony(self, request: ChatCompletionRequest, + raw_request: Request) -> Response: """ Chat Completion API with harmony format support. Supports both streaming and non-streaming modes. """ - async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams): + async def create_streaming_generator(promise: RequestOutput, + postproc_params: PostprocParams): async for res in promise: if not self.postproc_worker_enabled: post_processor, args = postproc_params.post_processor, postproc_params.postproc_args @@ -1079,7 +1201,8 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po tools_dict = [tool.model_dump() for tool in request.tools] # Reasoning effort precedence: request.reasoning_effort > system message parsing > serving default - reasoning_effort = maybe_transform_reasoning_effort(request.reasoning_effort) + reasoning_effort = maybe_transform_reasoning_effort( + request.reasoning_effort) # Get tool_choice from request tool_choice = getattr(request, 'tool_choice', None) @@ -1088,8 +1211,7 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po request.messages, tools_dict, reasoning_effort=reasoning_effort, - tool_choice=tool_choice - ) + tool_choice=tool_choice) except Exception as e: logger.error(f"messages_dict: {request.messages}") logger.error(f"tools_dict: {tools_dict}") @@ -1107,8 +1229,10 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po vocab_size=self.tokenizer.tokenizer.vocab_size, reasoning_parser="gpt_oss") sampling_params.detokenize = False # Harmony adapter handles detokenization - disaggregated_params = to_llm_disaggregated_params(request.disaggregated_params) - trace_headers = (None if raw_request is None else tracing.extract_trace_headers(raw_request.headers)) + disaggregated_params = to_llm_disaggregated_params( + request.disaggregated_params) + trace_headers = (None if raw_request is None else + tracing.extract_trace_headers(raw_request.headers)) postproc_args = ChatCompletionPostprocArgs.from_request(request) postproc_params = PostprocParams( @@ -1121,7 +1245,8 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po promise = self.generator.generate_async( inputs=harmony_tokens, sampling_params=sampling_params, - _postproc_params=postproc_params if self.postproc_worker_enabled else None, + _postproc_params=postproc_params + if self.postproc_worker_enabled else None, streaming=bool(request.stream), lora_request=request.lora_request, disaggregated_params=disaggregated_params, @@ -1137,10 +1262,9 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po # Handle streaming if request.stream: - return StreamingResponse( - content=create_streaming_generator(promise, postproc_params), - 
media_type="text/event-stream" - ) + return StreamingResponse(content=create_streaming_generator( + promise, postproc_params), + media_type="text/event-stream") else: response = await self._create_chat_response( promise, postproc_params, raw_request, disaggregated_params) @@ -1149,11 +1273,15 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po except Exception as e: logger.error("Error in harmony chat completion: %s", e) logger.debug("Error details: %s", traceback.format_exc()) - return self.create_error_response(message=str(e), err_type="internal_error") + return self.create_error_response(message=str(e), + err_type="internal_error") + + async def openai_responses(self, request: ResponsesRequest, + raw_request: Request) -> Response: - async def openai_responses(self, request: ResponsesRequest, raw_request: Request) -> Response: async def create_response( - promise: RequestOutput, postproc_params: PostprocParams) -> ResponsesResponse: + promise: RequestOutput, + postproc_params: PostprocParams) -> ResponsesResponse: await promise.aresult() if self.postproc_worker_enabled: response = promise.outputs[0]._postprocess_result @@ -1174,7 +1302,8 @@ async def create_response( return response - async def create_streaming_generator(promise: RequestOutput, postproc_params: PostprocParams): + async def create_streaming_generator(promise: RequestOutput, + postproc_params: PostprocParams): post_processor, args = postproc_params.post_processor, postproc_params.postproc_args streaming_processor = args.streaming_processor initial_responses = streaming_processor.get_initial_responses() @@ -1182,13 +1311,17 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po yield initial_response async for res in promise: - pp_results = res.outputs[0]._postprocess_result if self.postproc_worker_enabled else post_processor(res, args) + pp_results = res.outputs[ + 0]._postprocess_result if self.postproc_worker_enabled else post_processor( + res, args) for pp_res in pp_results: yield pp_res try: if request.background: - logger.warning("Request.background is not supported yet, will fallback to foreground processing.") + logger.warning( + "Request.background is not supported yet, will fallback to foreground processing." 
+ ) # Get prev response prev_response = None @@ -1196,12 +1329,16 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po prev_response_id = request.previous_response_id if prev_response_id is not None: if not prev_response_id.startswith("resp_"): - return self._create_invalid_response_id_error(prev_response_id) + return self._create_invalid_response_id_error( + prev_response_id) - prev_response = await self.conversation_store.load_response(prev_response_id) + prev_response = await self.conversation_store.load_response( + prev_response_id) if prev_response is None: - logger.debug(f"response_id {prev_response_id} not found") - return self._create_response_id_not_found_error(prev_response_id) + logger.debug( + f"response_id {prev_response_id} not found") + return self._create_response_id_not_found_error( + prev_response_id) input_tokens, sampling_params = await responses_api_request_preprocess( request=request, @@ -1210,9 +1347,11 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po enable_store=self.enable_store and request.store, use_harmony=self.use_harmony, tokenizer=self.tokenizer if not self.use_harmony else None, - model_config=self.model_config if not self.use_harmony else None, + model_config=self.model_config + if not self.use_harmony else None, processor=self.processor if not self.use_harmony else None, - reasoning_parser=self.generator.args.reasoning_parser if not self.use_harmony else "gpt_oss", + reasoning_parser=self.generator.args.reasoning_parser + if not self.use_harmony else "gpt_oss", ) streaming_processor = None @@ -1247,19 +1386,20 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po inputs=input_tokens, sampling_params=sampling_params, streaming=request.stream, - _postproc_params=postproc_params if self.postproc_worker_enabled else None, + _postproc_params=postproc_params + if self.postproc_worker_enabled else None, ) if self.postproc_worker_enabled and request.store: - logger.warning("Postproc workers are enabled, request will not be stored!") + logger.warning( + "Postproc workers are enabled, request will not be stored!") asyncio.create_task(self.await_disconnected(raw_request, promise)) if request.stream: - return StreamingResponse( - content=create_streaming_generator(promise, postproc_params), - media_type="text/event-stream" - ) + return StreamingResponse(content=create_streaming_generator( + promise, postproc_params), + media_type="text/event-stream") else: response = await create_response(promise, postproc_params) return JSONResponse(content=response.model_dump()) @@ -1273,10 +1413,13 @@ async def create_streaming_generator(promise: RequestOutput, postproc_params: Po return JSONResponse(content={"detail": "None"}) - async def openai_responses_get_response(self, response_id: str) -> JSONResponse: + async def openai_responses_get_response(self, + response_id: str) -> JSONResponse: logger.info(f"Getting response: {response_id}") if not self.enable_store: - return self.create_error_response(message="Response storage is disabled", err_type="InvalidRequestError") + return self.create_error_response( + message="Response storage is disabled", + err_type="InvalidRequestError") if not response_id.startswith("resp_"): return self._create_invalid_response_id_error(response_id) @@ -1287,10 +1430,13 @@ async def openai_responses_get_response(self, response_id: str) -> JSONResponse: return JSONResponse(content=response.model_dump()) - async def openai_responses_delete_response(self, 
response_id: str) -> JSONResponse: + async def openai_responses_delete_response( + self, response_id: str) -> JSONResponse: logger.info(f"Deleting response: {response_id}") if not self.enable_store: - return self.create_error_response(message="Response storage is disabled", err_type="InvalidRequestError") + return self.create_error_response( + message="Response storage is disabled", + err_type="InvalidRequestError") if not response_id.startswith("resp_"): return self._create_invalid_response_id_error(response_id) @@ -1305,29 +1451,38 @@ async def openai_responses_delete_response(self, response_id: str) -> JSONRespon "deleted": True }) - async def release_memory(self, request: MemoryUpdateRequest) -> JSONResponse: - assert isinstance(self.generator, AsyncLLM), "/release_memory endpoint is only supported with AsyncLLM()" - await self.generator.collective_rpc('sleep', args=(request.tags,)) + async def release_memory(self, + request: MemoryUpdateRequest) -> JSONResponse: + assert isinstance( + self.generator, AsyncLLM + ), "/release_memory endpoint is only supported with AsyncLLM()" + await self.generator.collective_rpc('sleep', args=(request.tags, )) return JSONResponse(content={"status": "success"}) async def resume_memory(self, request: MemoryUpdateRequest) -> JSONResponse: - assert isinstance(self.generator, AsyncLLM), "/resume_memory endpoint is only supported with AsyncLLM()" - await self.generator.collective_rpc('wakeup', args=(request.tags,)) + assert isinstance( + self.generator, AsyncLLM + ), "/resume_memory endpoint is only supported with AsyncLLM()" + await self.generator.collective_rpc('wakeup', args=(request.tags, )) return JSONResponse(content={"status": "success"}) - async def update_weights(self, request: UpdateWeightsRequest) -> JSONResponse: - assert isinstance(self.generator, AsyncLLM), "/update_weights endpoint is only supported with AsyncLLM()" - await self.generator.collective_rpc('update_weights', args=(request.weights,)) + async def update_weights(self, + request: UpdateWeightsRequest) -> JSONResponse: + assert isinstance( + self.generator, AsyncLLM + ), "/update_weights endpoint is only supported with AsyncLLM()" + await self.generator.collective_rpc('update_weights', + args=(request.weights, )) return JSONResponse(content={"status": "success"}) async def get_server_info(self) -> JSONResponse: - return JSONResponse(content={"disaggregated_params": self.generator.disaggregated_params}) + return JSONResponse( + content={ + "disaggregated_params": self.generator.disaggregated_params + }) - async def openai_image_generation( - self, - request: ImageGenerationRequest, - raw_request: Request - ) -> Response: + async def openai_image_generation(self, request: ImageGenerationRequest, + raw_request: Request) -> Response: """OpenAI-compatible image generation endpoint. Follows the OpenAI Images API specification for image generation. 
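The RL-only endpoints registered above (/release_memory, /resume_memory, /update_weights) are thin HTTP wrappers over collective_rpc calls into the AsyncLLM workers. A hypothetical client-side sketch; the base URL and tag value are assumptions, only the `tags` request field is taken from the handlers above:

    import requests

    BASE = "http://localhost:8000"  # assumed serving address

    # Ask every worker to release the tagged allocations, then restore them.
    requests.post(f"{BASE}/release_memory",
                  json={"tags": ["weights"]}).raise_for_status()
    requests.post(f"{BASE}/resume_memory",
                  json={"tags": ["weights"]}).raise_for_status()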
@@ -1335,10 +1490,17 @@ async def openai_image_generation( try: image_id = f"image_{uuid.uuid4().hex}" params = parse_visual_gen_params(request, image_id) - logger.info(f"Generating image: {image_id} with params: {params} and prompt: {request.prompt}") + logger.info( + f"Generating image: {image_id} with params: {params} and prompt: {request.prompt}" + ) if request.negative_prompt is not None: - inputs = visual_gen_inputs({"prompt": request.prompt, "negative_prompt": request.negative_prompt}) + inputs = visual_gen_inputs({ + "prompt": + request.prompt, + "negative_prompt": + request.negative_prompt + }) else: inputs = visual_gen_inputs(request.prompt) output = self.generator.generate(inputs=inputs, params=params) @@ -1361,10 +1523,11 @@ async def openai_image_generation( if request.response_format == "b64_json": data = [ - ImageObject( - b64_json=base64.b64encode(MediaStorage.convert_image_to_bytes(image)).decode('utf-8'), - revised_prompt=request.prompt - ) for image in output_images + ImageObject(b64_json=base64.b64encode( + MediaStorage.convert_image_to_bytes(image)).decode( + 'utf-8'), + revised_prompt=request.prompt) + for image in output_images ] response = ImageGenerationResponse( @@ -1375,7 +1538,8 @@ async def openai_image_generation( elif request.response_format == "url": # TODO: Support URL mode - return self._create_not_supported_error("URL mode is not supported for image generation") + return self._create_not_supported_error( + "URL mode is not supported for image generation") return JSONResponse(content=response.model_dump()) @@ -1383,12 +1547,8 @@ async def openai_image_generation( logger.error(traceback.format_exc()) return self.create_error_response(str(e)) - - async def openai_image_edit( - self, - request: ImageEditRequest, - raw_request: Request - ) -> Response: + async def openai_image_edit(self, request: ImageEditRequest, + raw_request: Request) -> Response: """OpenAI-compatible image editing endpoint. Follows the OpenAI Images API specification for image editing. 
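Both image endpoints return ImageObject entries whose b64_json field carries the encoded image bytes when response_format == "b64_json" (URL mode is rejected above). A hypothetical client sketch; the address, prompt, and .png extension are assumptions, and the actual byte format comes from MediaStorage.convert_image_to_bytes:

    import base64
    import requests

    resp = requests.post(
        "http://localhost:8000/v1/images/generations",  # assumed address
        json={"prompt": "a watercolor fox", "response_format": "b64_json"},
    )
    resp.raise_for_status()
    # Each entry of `data` is an ImageObject; decode its base64 payload to disk.
    for i, image in enumerate(resp.json()["data"]):
        with open(f"image_{i}.png", "wb") as f:  # extension is an assumption
            f.write(base64.b64decode(image["b64_json"]))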
@@ -1397,10 +1557,17 @@ async def openai_image_edit( try: image_id = f"image_{uuid.uuid4().hex}" params = parse_visual_gen_params(request, image_id) - logger.info(f"Editing image: {image_id} with params: {params} and prompt: {request.prompt}") + logger.info( + f"Editing image: {image_id} with params: {params} and prompt: {request.prompt}" + ) if request.negative_prompt is not None: - inputs = visual_gen_inputs({"prompt": request.prompt, "negative_prompt": request.negative_prompt}) + inputs = visual_gen_inputs({ + "prompt": + request.prompt, + "negative_prompt": + request.negative_prompt + }) else: inputs = visual_gen_inputs(request.prompt) output = self.generator.generate(inputs=inputs, params=params) @@ -1424,10 +1591,11 @@ async def openai_image_edit( response = ImageGenerationResponse( created=int(time.time()), data=[ - ImageObject( - b64_json=base64.b64encode(MediaStorage.convert_image_to_bytes(image)).decode('utf-8'), - revised_prompt=request.prompt - ) for image in output_images + ImageObject(b64_json=base64.b64encode( + MediaStorage.convert_image_to_bytes(image)).decode( + 'utf-8'), + revised_prompt=request.prompt) + for image in output_images ], size=f"{params.width}x{params.height}", ) @@ -1436,12 +1604,13 @@ async def openai_image_edit( except Exception as e: logger.error(traceback.format_exc()) - return self.create_error_response(message=str(e), err_type="InternalServerError", status_code=HTTPStatus.INTERNAL_SERVER_ERROR) + return self.create_error_response( + message=str(e), + err_type="InternalServerError", + status_code=HTTPStatus.INTERNAL_SERVER_ERROR) - async def openai_video_generation_sync( - self, - raw_request: Request - ) -> Response: + async def openai_video_generation_sync(self, + raw_request: Request) -> Response: """Synchronous video generation endpoint. Waits for video generation to complete before returning. @@ -1460,11 +1629,21 @@ async def openai_video_generation_sync( request.output_format) video_id = f"video_{uuid.uuid4().hex}" - params = parse_visual_gen_params(request, video_id, media_storage_path=str(self.media_storage_path)) - logger.info(f"Generating video: {video_id} with params: {params} and prompt: {request.prompt}") + params = parse_visual_gen_params(request, + video_id, + media_storage_path=str( + self.media_storage_path)) + logger.info( + f"Generating video: {video_id} with params: {params} and prompt: {request.prompt}" + ) if request.negative_prompt is not None: - inputs = visual_gen_inputs({"prompt": request.prompt, "negative_prompt": request.negative_prompt}) + inputs = visual_gen_inputs({ + "prompt": + request.prompt, + "negative_prompt": + request.negative_prompt + }) else: inputs = visual_gen_inputs(request.prompt) output = self.generator.generate(inputs=inputs, params=params) @@ -1560,7 +1739,9 @@ async def _parse_video_generation_request( return VideoGenerationRequest(**data) else: - raise ValueError(f"Unsupported content-type: {content_type}. Use 'application/json' or 'multipart/form-data'") + raise ValueError( + f"Unsupported content-type: {content_type}. 
Use 'application/json' or 'multipart/form-data'" + ) async def openai_video_generation_async( self, @@ -1581,8 +1762,13 @@ async def openai_video_generation_async( request = await self._parse_video_generation_request(raw_request) video_id = f"video_{uuid.uuid4().hex}" - params = parse_visual_gen_params(request, video_id, media_storage_path=str(self.media_storage_path)) - logger.info(f"Generating video: {video_id} with params: {params} and prompt: {request.prompt}") + params = parse_visual_gen_params(request, + video_id, + media_storage_path=str( + self.media_storage_path)) + logger.info( + f"Generating video: {video_id} with params: {params} and prompt: {request.prompt}" + ) # Start background generation task self.video_gen_tasks[video_id] = asyncio.create_task( @@ -1590,8 +1776,7 @@ async def openai_video_generation_async( video_id=video_id, request=request, params=params, - ) - ) + )) # Return job metadata immediately video_job = VideoJob( @@ -1628,7 +1813,12 @@ async def _generate_video_background( request.output_format) if request.negative_prompt is not None: - inputs = visual_gen_inputs({"prompt": request.prompt, "negative_prompt": request.negative_prompt}) + inputs = visual_gen_inputs({ + "prompt": + request.prompt, + "negative_prompt": + request.negative_prompt + }) else: inputs = visual_gen_inputs(request.prompt) future = self.generator.generate_async(inputs=inputs, params=params) @@ -1668,10 +1858,7 @@ async def _generate_video_background( job.error = str(e) await VIDEO_STORE.upsert(video_id, job) - async def list_videos( - self, - raw_request: Request - ) -> Response: + async def list_videos(self, raw_request: Request) -> Response: """List all generated videos. GET /v1/videos @@ -1682,20 +1869,15 @@ async def list_videos( video_jobs = await VIDEO_STORE.list_values() # Convert to API format - response = VideoJobList( - data=video_jobs, - ) + response = VideoJobList(data=video_jobs, ) return JSONResponse(content=response.model_dump()) except Exception as e: logger.error(traceback.format_exc()) return self.create_error_response(str(e)) - async def get_video_metadata( - self, - video_id: str, - raw_request: Request - ) -> Response: + async def get_video_metadata(self, video_id: str, + raw_request: Request) -> Response: """Get video metadata by ID. GET /v1/videos/{video_id} @@ -1709,16 +1891,14 @@ async def get_video_metadata( return self.create_error_response( f"Video {video_id} not found", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND - ) + status_code=HTTPStatus.NOT_FOUND) # Ensure it's a video if job.object != "video": return self.create_error_response( f"Resource {video_id} is not a video", err_type="BadRequestError", - status_code=HTTPStatus.BAD_REQUEST - ) + status_code=HTTPStatus.BAD_REQUEST) return JSONResponse(content=job.model_dump()) @@ -1726,11 +1906,8 @@ async def get_video_metadata( logger.error(traceback.format_exc()) return self.create_error_response(str(e)) - async def get_video_content( - self, - video_id: str, - raw_request: Request - ) -> Response: + async def get_video_content(self, video_id: str, + raw_request: Request) -> Response: """Download video file by ID. 
GET /v1/videos/{video_id}/content @@ -1743,23 +1920,20 @@ async def get_video_content( return self.create_error_response( f"Video {video_id} not found", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND - ) + status_code=HTTPStatus.NOT_FOUND) # Ensure it's a video and completed if job.object != "video": return self.create_error_response( f"Resource {video_id} is not a video", err_type="BadRequestError", - status_code=HTTPStatus.BAD_REQUEST - ) + status_code=HTTPStatus.BAD_REQUEST) if job.status != "completed": return self.create_error_response( f"Video {video_id} is not ready (status: {job.status})", err_type="BadRequestError", - status_code=HTTPStatus.BAD_REQUEST - ) + status_code=HTTPStatus.BAD_REQUEST) # Try to use stored output path, otherwise check for both .mp4 and .avi video_path = None @@ -1784,18 +1958,14 @@ async def get_video_content( return self.create_error_response( f"Video {video_id} not found", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND - ) + status_code=HTTPStatus.NOT_FOUND) except Exception as e: logger.error(traceback.format_exc()) return self.create_error_response(str(e)) - async def delete_video( - self, - video_id: str, - raw_request: Request - ) -> Response: + async def delete_video(self, video_id: str, + raw_request: Request) -> Response: """Delete a video by ID. DELETE /v1/videos/{video_id} @@ -1808,16 +1978,14 @@ async def delete_video( return self.create_error_response( f"Video {video_id} not found", err_type="NotFoundError", - status_code=HTTPStatus.NOT_FOUND - ) + status_code=HTTPStatus.NOT_FOUND) # Ensure it's a video if job.object != "video": return self.create_error_response( f"Resource {video_id} is not a video", err_type="BadRequestError", - status_code=HTTPStatus.BAD_REQUEST - ) + status_code=HTTPStatus.BAD_REQUEST) # Delete the video file(s) - check for both .mp4 and .avi video_path = None @@ -1843,7 +2011,10 @@ async def delete_video( logger.error(traceback.format_exc()) return self.create_error_response(str(e)) - async def __call__(self, host, port, sockets: list[socket.socket] | None = None): + async def __call__(self, + host, + port, + sockets: list[socket.socket] | None = None): # Store the binding address for server registration self.binding_addr = f"http://{host}:{port}" self.host = host diff --git a/tensorrt_llm/serve/scripts/time_breakdown/time_breakdown.py b/tensorrt_llm/serve/scripts/time_breakdown/time_breakdown.py index 60aaed35092..19ad6767d6f 100644 --- a/tensorrt_llm/serve/scripts/time_breakdown/time_breakdown.py +++ b/tensorrt_llm/serve/scripts/time_breakdown/time_breakdown.py @@ -390,10 +390,10 @@ def create_timing_diagram(self, # Calculate e2e time for each request def get_e2e_time(data): """Calculate end-to-end time from arrival to last token.""" - # Get arrival time - arrival = data.get('ctx_server_arrival_time', float('nan')) + # Get arrival time (executor arrival first, then server arrival) + arrival = data.get('ctx_arrival_time', float('nan')) if math.isnan(arrival): - arrival = data.get('ctx_arrival_time', float('nan')) + arrival = data.get('ctx_server_arrival_time', float('nan')) if math.isnan(arrival): arrival = data.get('disagg_server_arrival_time', float('nan')) @@ -407,9 +407,9 @@ def get_e2e_time(data): return 0 def get_arrival_time(data): - arrival = data.get('ctx_server_arrival_time', float('nan')) + arrival = data.get('ctx_arrival_time', float('nan')) if math.isnan(arrival): - arrival = data.get('ctx_arrival_time', float('nan')) + arrival = data.get('ctx_server_arrival_time', float('nan')) if 
math.isnan(arrival): arrival = data.get('disagg_server_arrival_time', float('inf')) return arrival
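Both hunks above reorder the same NaN-guarded fallback chain, so the two helpers now agree that the executor's ctx_arrival_time wins over the server's. The pattern generalizes to a small utility; a sketch under the same field names (first_valid and the sample record are hypothetical, not part of the patch):

    import math

    def first_valid(data: dict, keys: list, default: float) -> float:
        """Return data[key] for the first key whose value is present and not NaN."""
        for key in keys:
            value = data.get(key, float('nan'))
            if not math.isnan(value):
                return value
        return default

    # Mirrors get_arrival_time() after the patch: executor arrival first,
    # then server arrival, then the disagg server's arrival time.
    record = {'ctx_server_arrival_time': 12.5}  # only the server timestamp present
    arrival = first_valid(record,
                          ['ctx_arrival_time', 'ctx_server_arrival_time'],
                          default=record.get('disagg_server_arrival_time',
                                             float('inf')))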