diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py
index 409c726d10c..49a972fc4e2 100644
--- a/megatron/core/inference/engines/dynamic_engine.py
+++ b/megatron/core/inference/engines/dynamic_engine.py
@@ -798,6 +798,47 @@ async def _notify_cond_for_new_request(self):
         async with self._cond:
             self._cond.notify_all()
 
+    def _handle_failed_request(self, request_id: int):
+        """Handle a failed request by sending the reply immediately.
+
+        The request is added to failed_request_ids so that the next bookkeeping pass can return it.
+        """
+        request_entry = self.requests[request_id]
+        request = request_entry.record[-1]
+
+        if self.rank == 0:
+            warnings.warn(
+                f"Request {request_id} failed to be added to the engine due to errors. "
+                f"Prompt Tokens: {len(request.prompt_tokens)} "
+                f"Tokens to generate: {request.sampling_params.num_tokens_to_generate} "
+                f"Max sequence length: {self.context.max_sequence_length} "
+                f"Chunked prefill enabled: {self.enable_chunked_prefill}"
+            )
+
+        request.status = Status.FAILED
+        request.add_event_fail()
+        self.failed_request_ids.append(request_id)
+
+        # Send the reply immediately, because it may never get a chance to be sent again.
+        if self.use_coordinator and self.is_mp_coordinator:
+            payload = msgpack.packb(
+                [Headers.ENGINE_REPLY.value, [request_entry.record.merge().serialize()]],
+                use_bin_type=True,
+            )
+            self.socket_for_receiving_requests.send(payload)
+        elif not self.use_coordinator:
+            if request.prompt is None:
+                request.prompt = self.controller.tokenizer.detokenize(
+                    request.prompt_tokens.tolist()
+                )
+            if request.generated_tokens:
+                request.generated_text = self.controller.tokenizer.detokenize(
+                    request.generated_tokens
+                )
+            else:
+                request.generated_text = ""
+            request_entry.future.set_result(request_entry.record)
+
     def has_unfinished_requests(self) -> bool:
         """Test if context contains unfinished requests."""
         return self.context.has_unfinished_requests() or len(self.waiting_request_ids) > 0
@@ -874,16 +915,10 @@ def _add_request(
             len(request.prompt_tokens) + request.sampling_params.num_tokens_to_generate
             > self.context.max_sequence_length
         ) or (request.sampling_params.num_tokens_to_generate < 0):
-            logging.error(
-                f"{request_id=} Invalid number of tokens to generate. Prompt len: {len(request.prompt_tokens)}, tokens to generate: {request.sampling_params.num_tokens_to_generate}, max seq len: {self.context.max_sequence_length}."
-            )
             request.status = Status.FAILED
             request.add_event_error_nontransient(MaxSequenceLengthOverflowError(request_id))
 
         if len(request.prompt_tokens) > self.context.max_tokens and not self.enable_chunked_prefill:
-            logging.error(
-                f"{request_id=} Prompt is longer than context.max_tokens. Prompt tokens: {len(request.prompt_tokens)}, context.max_tokens: {self.context.max_tokens}, chunked_prefill: {self.enable_chunked_prefill}"
-            )
             request.status = Status.FAILED
             request.add_event_error_nontransient(TokenOverflowError(request_id))
 
@@ -898,14 +933,7 @@ def _add_request(
         if request.status != Status.FAILED:
             self.waiting_request_ids.append(request_id)
         else:
-            self.failed_request_ids.append(request_id)
-            if self.rank == 0:
-                warnings.warn(
-                    f"Request {request_id} failed to be added to the engine due to errors. "
-                    f"Prompt Tokens: {len(request.prompt_tokens)} "
-                    f"Tokens to generate: {request.sampling_params.num_tokens_to_generate} "
-                    f"Max sequence length: {self.context.max_sequence_length} "
-                )
+            self._handle_failed_request(request_id)
 
         return self.requests[request_id].future
 
@@ -1616,14 +1644,14 @@ async def async_bookkeep(
         active_request_ids: list[int] = []
         finished_request_records: list[DynamicInferenceRequestRecord] = []
 
-        # Failed requests.
+        # Failed requests. Status and events were already set in _handle_failed_request;
+        # here we just clean up the entry and include it in finished_request_records.
        for failed_request_id in self.failed_request_ids:
             failed_entry = self.requests.pop(failed_request_id)
-            failed_request = failed_entry.record[-1]
-            failed_request.status = Status.FAILED
-            failed_request.add_event_fail()
             finished_request_records.append(failed_entry.record)
-            failed_entry.future.set_result(failed_entry.record)
+            assert (
+                failed_entry.future.done()
+            ), f"Failed request {failed_request_id} future has not been properly resolved."
         self.failed_request_ids.clear()
         range_pop()
 
@@ -1644,17 +1672,20 @@ async def async_bookkeep(
         range_pop()
 
         # Handle necessary ZMQ DP coordinator communication.
-        if self.use_coordinator and self.is_mp_coordinator and finished_request_records:
-            range_push("coordinator_communication")
-            payload = msgpack.packb(
-                [
-                    Headers.ENGINE_REPLY.value,
-                    [r.merge().serialize() for r in finished_request_records],
-                ],
-                use_bin_type=True,
-            )
-            self.socket_for_receiving_requests.send(payload)
-            range_pop()
+        # Failed request replies were already sent in _handle_failed_request,
+        # so only send completed records here.
+        if self.use_coordinator and self.is_mp_coordinator:
+            records_to_send = [
+                r for r in finished_request_records if r.requests[-1].status != Status.FAILED
+            ]
+            if records_to_send:
+                range_push("coordinator_communication")
+                payload = msgpack.packb(
+                    [Headers.ENGINE_REPLY.value, [r.merge().serialize() for r in records_to_send]],
+                    use_bin_type=True,
+                )
+                self.socket_for_receiving_requests.send(payload)
+                range_pop()
 
         # Log KV cache utilization stats to W&B
         if context_state["kv_stats"] is not None:
diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py
index 98f5219a3bf..bc0321b9213 100644
--- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py
+++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py
@@ -8,7 +8,10 @@
 import uuid
 import warnings
 
-from megatron.core.inference.inference_request import unwrap_serialized_tensors
+from megatron.core.inference.inference_request import (
+    DynamicInferenceEventType,
+    unwrap_serialized_tensors,
+)
 from megatron.core.inference.sampling_params import SamplingParams
 from megatron.core.tokenizers.text.parsers import PARSER_MAPPING
 
@@ -369,7 +372,35 @@ async def chat_completions():
             f"{time.perf_counter() - start_time:.2f}s"
         )
 
-    # --- 4. Format OpenAI Response ---
+    # --- 4. Check for failed requests ---
+    failed_errors = []
+    has_nontransient_error = False
+    for i, record in enumerate(batch_results):
+        last_request = record.requests[-1]
+        if last_request.failed():
+            error_events = [
+                e
+                for e in last_request.events
+                if e.type
+                in (
+                    DynamicInferenceEventType.ERROR_NONTRANSIENT,
+                    DynamicInferenceEventType.ERROR_TRANSIENT,
+                )
+            ]
+            if any(
+                e.type == DynamicInferenceEventType.ERROR_NONTRANSIENT for e in error_events
+            ):
+                has_nontransient_error = True
+            error_msg = str(error_events[-1].payload) if error_events else "Unknown error"
+            failed_errors.append(f"Request {i}: {error_msg}")
+
+    if failed_errors:
+        error_detail = "; ".join(failed_errors)
+        status = 400 if has_nontransient_error else 500
+        logger.error(f"Inference request(s) failed: {error_detail}")
+        return Response(f"Inference request(s) failed: {error_detail}", status=status)
+
+    # --- 5. Format OpenAI Response ---
     choices = []
     total_completion_tokens = 0
     prompt_tokens_counts = []
@@ -379,17 +410,6 @@ async def chat_completions():
         result = result_item if isinstance(result_item, dict) else result_item.serialize()
         result = unwrap_serialized_tensors(result)
 
-        if result["status"] == "FAILED":
-            if result["sampling_params"]["num_tokens_to_generate"] <= 0:
-                return Response(
-                    f"Request {request_idx} failed due to context length overflow", status=400
-                )
-            else:
-                return Response(
-                    f"Request {request_idx} failed due to internal error {result['events']}",
-                    status=500,
-                )
-
         prompt_tokens_out = result["prompt_tokens"]  # The engine can modify prompt_tokens.
         text_output = result["generated_text"]
         prompt_tokens_count = len(prompt_tokens_out) if prompt_tokens_out is not None else 0
diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py
index af8ec41aac2..437f5aba5ee 100644
--- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py
+++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/completions.py
@@ -4,7 +4,10 @@
 import logging
 import time
 
-from megatron.core.inference.inference_request import unwrap_serialized_tensors
+from megatron.core.inference.inference_request import (
+    DynamicInferenceEventType,
+    unwrap_serialized_tensors,
+)
 from megatron.core.inference.sampling_params import SamplingParams
 
 logger = logging.getLogger(__name__)
@@ -125,7 +128,35 @@ async def completions():
             f"{time.perf_counter() - start_time:.2f}s"
         )
 
-    # --- 4. Format Response (matching old_completions.py) ---
+    # --- 4. Check for failed requests ---
+    failed_errors = []
+    has_nontransient_error = False
+    for i, record in enumerate(batch_results):
+        last_request = record.requests[-1]
+        if last_request.failed():
+            error_events = [
+                e
+                for e in last_request.events
+                if e.type
+                in (
+                    DynamicInferenceEventType.ERROR_NONTRANSIENT,
+                    DynamicInferenceEventType.ERROR_TRANSIENT,
+                )
+            ]
+            if any(
+                e.type == DynamicInferenceEventType.ERROR_NONTRANSIENT for e in error_events
+            ):
+                has_nontransient_error = True
+            error_msg = str(error_events[-1].payload) if error_events else "Unknown error"
+            failed_errors.append(f"Request {i}: {error_msg}")
+
+    if failed_errors:
+        error_detail = "; ".join(failed_errors)
+        status = 400 if has_nontransient_error else 500
+        logger.error(f"Inference request(s) failed: {error_detail}")
+        return f"Inference request(s) failed: {error_detail}", status
+
+    # --- 5. Format Response (matching old_completions.py) ---
     choices = []
     request_idx = 0
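A minimal, hypothetical client-side sketch of the failure contract the two endpoints now share: a 400 response signals a nontransient error (the request is invalid as written, e.g. a context-length overflow), while a 500 signals a transient engine error that may succeed on retry. This sketch is not part of the patch; the server URL and JSON fields are illustrative assumptions.

import requests  # assumed available; any HTTP client works the same way

# Hypothetical endpoint and payload, for illustration only.
resp = requests.post(
    "http://localhost:5000/completions",
    json={"prompt": "Hello, world", "max_tokens": 16},
)
if resp.status_code == 400:
    # Nontransient failure (e.g. MaxSequenceLengthOverflowError, TokenOverflowError):
    # resubmitting the identical request will fail again, so fix the request first.
    raise ValueError(f"Bad request: {resp.text}")
elif resp.status_code == 500:
    # Transient failure: the same request may succeed if resubmitted.
    print(f"Transient engine failure, consider retrying: {resp.text}")
else:
    print(resp.json())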