Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 47 additions & 18 deletions megatron/core/inference/engines/dynamic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,34 @@ def get_request(self, request_id: int) -> DynamicInferenceRequest:
"""
return self.requests[request_id].record[-1]

def _handle_failed_request(self, request_id: int):
"""Handle a failed request by sending the reply immediately.

The request is added to failed_request_ids so that the next bookkeeping pass can return it.
"""
request_entry = self.requests[request_id]
request = request_entry.record[-1]

if self.rank == 0:
warnings.warn(f"Request {request_id} failed to be added to the engine due to errors.")

request.add_event_fail()
self.failed_request_ids.append(request_id)

# Send the reply immediately, because it may never get a chance to be sent again.
if self.use_coordinator and self.is_mp_coordinator:
payload = msgpack.packb(
[Headers.ENGINE_REPLY.value, [entry.record.serialize()]], use_bin_type=True
)
self.socket_for_receiving_requests.send(payload)
elif not self.use_coordinator:
if request.prompt is None:
request.prompt = self.controller.tokenizer.detokenize(
request.prompt_tokens.tolist()
)
request.generated_text = self.controller.tokenizer.detokenize(request.generated_tokens)
entry.future.set_result(entry.record)

def _add_request(
self, request: DynamicInferenceRequest
) -> asyncio.Future[DynamicInferenceRequest]:
Expand Down Expand Up @@ -863,11 +891,7 @@ def _add_request(
if request.status != Status.FAILED:
self.waiting_request_ids.append(request_id)
else:
self.failed_request_ids.append(request_id)
if self.rank == 0:
warnings.warn(
f"Request {request_id} failed to be added to the engine due to errors."
)
self._handle_failed_request(request_id)

return self.requests[request_id].future

Expand Down Expand Up @@ -1501,14 +1525,14 @@ async def async_bookkeep(
active_request_ids: list[int] = []
finished_request_records: list[DynamicInferenceRequestRecord] = []

# Failed requests.
# Failed requests. Status and events were already set in _handle_failed_request;
# here we just clean up the entry and include it in finished_request_records.
for failed_request_id in self.failed_request_ids:
failed_entry = self.requests.pop(failed_request_id)
failed_request = failed_entry.record[-1]
failed_request.status = Status.FAILED
failed_request.add_event_fail()
finished_request_records.append(failed_entry.record)
failed_entry.future.set_result(failed_entry.record)
assert (
failed_entry.future.done()
), f"Failed request {failed_request_id} future has not been properly resolved."
self.failed_request_ids.clear()
range_pop()

Expand All @@ -1529,14 +1553,19 @@ async def async_bookkeep(
range_pop()

# Handle necessary ZMQ DP coordinator communication.
if self.use_coordinator and self.is_mp_coordinator and finished_request_records:
range_push("coordinator_communication")
payload = msgpack.packb(
[Headers.ENGINE_REPLY.value, [r.serialize() for r in finished_request_records]],
use_bin_type=True,
)
self.socket_for_receiving_requests.send(payload)
range_pop()
# Failed request replies were already sent in _handle_failed_request.
if self.use_coordinator and self.is_mp_coordinator:
records_to_send = [
r for r in finished_request_records if r.requests[-1].status != Status.FAILED
]
if records_to_send:
range_push("coordinator_communication")
payload = msgpack.packb(
[Headers.ENGINE_REPLY.value, [r.serialize() for r in records_to_send]],
use_bin_type=True,
)
self.socket_for_receiving_requests.send(payload)
range_pop()

# Log KV cache utilization stats to W&B
if context_state["kv_stats"] is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import uuid
import warnings

from megatron.core.inference.inference_request import DynamicInferenceEventType
from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.tokenizers.text.parsers import PARSER_MAPPING

Expand Down Expand Up @@ -151,7 +152,29 @@ async def chat_completions():
f"{time.perf_counter() - start_time:.2f}s"
)

# --- 4. Format OpenAI Response ---
# --- 4. Check for failed requests ---
failed_errors = []
for i, record in enumerate(batch_results):
last_request = record.requests[-1]
if last_request.failed():
error_events = [
e
for e in last_request.events
if e.type
in (
DynamicInferenceEventType.ERROR_NONTRANSIENT,
DynamicInferenceEventType.ERROR_TRANSIENT,
)
]
error_msg = str(error_events[-1].payload) if error_events else "Unknown error"
failed_errors.append(f"Request {i}: {error_msg}")

if failed_errors:
error_detail = "; ".join(failed_errors)
logger.error(f"Inference request(s) failed: {error_detail}")
return Response(f"Inference request(s) failed: {error_detail}", status=400)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: HTTP 400 indicates a client error ("bad request"), but inference failures can also be server-side (e.g. ERROR_TRANSIENT). Returning 400 for transient/server errors is misleading to clients — they may not retry when they should. Consider using 500 (or 503 for transient errors) instead, or at minimum differentiating based on the event type since you already distinguish ERROR_TRANSIENT vs ERROR_NONTRANSIENT.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Claude, thank you for your feedback. I have addressed this concern!


# --- 5. Format OpenAI Response ---
choices = []
total_completion_tokens = 0
prompt_tokens_counts = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import time

from megatron.core.inference.inference_request import DynamicInferenceEventType
from megatron.core.inference.sampling_params import SamplingParams

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -124,7 +125,29 @@ async def completions():
f"{time.perf_counter() - start_time:.2f}s"
)

# --- 4. Format Response (matching old_completions.py) ---
# --- 4. Check for failed requests ---
failed_errors = []
for i, record in enumerate(batch_results):
last_request = record.requests[-1]
if last_request.failed():
error_events = [
e
for e in last_request.events
if e.type
in (
DynamicInferenceEventType.ERROR_NONTRANSIENT,
DynamicInferenceEventType.ERROR_TRANSIENT,
)
]
error_msg = str(error_events[-1].payload) if error_events else "Unknown error"
failed_errors.append(f"Request {i}: {error_msg}")

if failed_errors:
error_detail = "; ".join(failed_errors)
logger.error(f"Inference request(s) failed: {error_detail}")
return f"Inference request(s) failed: {error_detail}", 400
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same HTTP 400 concern as in chat_completions.py — transient/server-side errors should not be reported as 400.

Also minor inconsistency: chat_completions.py returns Response(msg, status=400) while this file returns a tuple (msg, 400). Both work in Flask/Quart, but it would be cleaner to use the same style (the existing exception handler on line 120 already uses the tuple form, so the tuple is fine here — just noting the cross-file inconsistency).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See other thread.


# --- 5. Format Response (matching old_completions.py) ---
choices = []

request_idx = 0
Expand Down
Loading