Skip to content
Merged
Show file tree
Hide file tree
Changes from 55 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
ff01c9a
Fix the coordinator signaling logic
tdene Feb 19, 2026
b3e59f0
Fix coordinator shutdown
tdene Feb 20, 2026
f8c8025
Fix coordinator unit tests
tdene Feb 19, 2026
fc9d093
Address reviewer comments
tdene Feb 24, 2026
8d6eeee
Address reviewer comments
tdene Feb 24, 2026
9d0ff35
Address reviewer comments
tdene Feb 24, 2026
bb961a3
Cleanup
tdene Feb 24, 2026
a87ea0b
Address reviewer comments
tdene Feb 24, 2026
d0c19a9
Cleanup
tdene Feb 24, 2026
3fb373c
Fix cleanup
tdene Feb 24, 2026
291b820
Cleanup
tdene Feb 24, 2026
683312f
Cleanup
tdene Feb 24, 2026
526ccb2
Cleanup
tdene Feb 24, 2026
178680c
Correct shutdown
tdene Feb 24, 2026
8c426d7
Correct pause, no ACK
tdene Feb 24, 2026
2e2c8d5
lint
tdene Feb 25, 2026
e8d55b9
Fix cleanup
tdene Feb 25, 2026
68e3c6f
Fix cleanup
tdene Feb 25, 2026
5a3e1bb
Clean up test shutdown
tdene Feb 25, 2026
611de48
Cleanup
tdene Feb 26, 2026
b3605b8
Fix typo
tdene Feb 26, 2026
45b7436
tmp
tdene Feb 26, 2026
8081aff
Final cleanup
tdene Feb 26, 2026
a7f2dab
Cleanup tests
tdene Feb 26, 2026
aac10bb
Fix golden values
tdene Feb 26, 2026
690ab38
lint
tdene Feb 26, 2026
cee4c28
Fix CI unit test
tdene Feb 26, 2026
97d3343
Address reviewer comments
tdene Feb 26, 2026
136cfdc
Fix CI
tdene Feb 27, 2026
dbc7117
Address reviewer comments
tdene Feb 27, 2026
9ce666b
Merge remote-tracking branch 'gh/main' into tde/robust_coordinator_si…
tdene Feb 27, 2026
6d3afc0
Remove immediate shutdown logic
tdene Feb 27, 2026
45ffc2b
Fix CI
tdene Feb 28, 2026
3eb054f
Address reviewer comments
tdene Feb 28, 2026
f8639d0
lint
tdene Feb 28, 2026
9a17764
Fix transposed if statement
tdene Feb 28, 2026
bdfe52c
Fix spurious CI hang
tdene Mar 1, 2026
13cf3ac
Add UNPAUSING state
tdene Mar 2, 2026
6e93217
Merge remote-tracking branch 'gh/main' into tde/robust_coordinator_si…
tdene Mar 2, 2026
ee4f148
Add functional test
tdene Mar 2, 2026
ac01210
Minor inference changes for NemoRL
ArEsKay3 Mar 3, 2026
e91e0e2
lint
tdene Mar 3, 2026
d025a9c
Merge remote-tracking branch 'gh/main' into tde/robust_coordinator_si…
tdene Mar 3, 2026
c3a315e
Fix lack of barrier after resume
tdene Mar 4, 2026
78af3ca
Merge remote-tracking branch 'gh/main' into tde/robust_coordinator_si…
tdene Mar 4, 2026
9de851d
Fix mamba states to float 32
Mar 4, 2026
742327c
Merge remote-tracking branch 'ArEsKay3/rkirby/minor_inference_patches…
Mar 5, 2026
1e44948
Fix stuff for TP
Mar 6, 2026
ccdf4bc
Normalize tool_calls and gate parser tool-calls to tool-enabled requests
i-riyad Mar 4, 2026
c38ff24
Harden chat template inputs for tool-calling payloads.
i-riyad Mar 6, 2026
64aee21
Fix finish reason, add support for preventing retokenization of prompt
ArEsKay3 Mar 9, 2026
3e9b488
Fixing merge conflicts
shanmugamr1992 Mar 10, 2026
a603cf5
Fixing merge conflicts
shanmugamr1992 Mar 10, 2026
bbefd47
Merge branch 'main' into shanmugamr1992/megatron_inference_ultra
shanmugamr1992 Mar 10, 2026
f71f013
Merge branch 'main' into shanmugamr1992/megatron_inference_ultra
shanmugamr1992 Mar 10, 2026
64cf04b
Fixing pylint stuff
shanmugamr1992 Mar 11, 2026
d43aeca
Merge branch 'main' into shanmugamr1992/megatron_inference_ultra
shanmugamr1992 Mar 11, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion megatron/core/inference/engines/dynamic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,10 @@ def _add_request(
self.failed_request_ids.append(request_id)
if self.rank == 0:
warnings.warn(
f"Request {request_id} failed to be added to the engine due to errors."
f"Request {request_id} failed to be added to the engine due to errors. " \
f"Prompt Tokens: {len(request.prompt_tokens)} " \
f"Tokens to generate: {request.sampling_params.num_tokens_to_generate} " \
f"Max sequence length: {self.context.max_sequence_length} "
)

return self.requests[request_id].future
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,158 @@
import warnings

from megatron.core.inference.sampling_params import SamplingParams
from megatron.core.inference.inference_request import DynamicInferenceRequest
from megatron.core.tokenizers.text.parsers import PARSER_MAPPING

logger = logging.getLogger(__name__)


def _get_field(obj, key, default=None):
"""Read a field from dict-like or object-like values."""
if isinstance(obj, dict):
return obj.get(key, default)
return getattr(obj, key, default)


def _normalize_tool_calls(tool_calls):
    """Coerce parser-produced tool calls into OpenAI-compatible dicts.

    Each output entry has the shape::

        {"id": str, "type": "function",
         "function": {"name": str, "arguments": str}}

    Calls without a function name are dropped.  Non-string arguments are
    serialized to JSON (falling back to ``str`` for unserializable values).
    ``None`` or empty input yields an empty list.
    """
    result = []
    for raw_call in tool_calls or []:
        function = _get_field(raw_call, "function", {}) or {}
        name = _get_field(function, "name")
        arguments = _get_field(function, "arguments", "")
        if name is None:
            # A tool call without a function name is unusable downstream.
            continue
        if not isinstance(arguments, str):
            try:
                arguments = json.dumps(arguments, ensure_ascii=False)
            except TypeError:
                arguments = str(arguments)
        call_id = _get_field(raw_call, "id", f"call_{uuid.uuid4().hex[:24]}")
        result.append(
            {
                "id": str(call_id),
                "type": "function",
                "function": {"name": str(name), "arguments": arguments},
            }
        )
    return result


def _coerce_arguments_mapping(arguments):
"""Coerce function.arguments to a mapping for HF/Jinja chat templates.

Examples:
- {"x": 1} -> {"x": 1}
- '{"x": 1}' -> {"x": 1}
- "[1, 2]" -> {} # JSON parses, but not a mapping
- "not-json" -> {}
- None -> {}
"""
if isinstance(arguments, dict):
return arguments
if isinstance(arguments, str):
try:
parsed = json.loads(arguments)
except (TypeError, ValueError):
return {}
return parsed if isinstance(parsed, dict) else {}
return {}


def _sanitize_messages_for_template(messages):
    """Prepare messages so tokenizer chat templates can safely consume them.

    Only tool-call argument payloads are touched:
    ``messages[*].tool_calls[*].function.arguments`` is coerced to a dict.
    Containers are shallow-copied, so the caller's structures are never
    mutated.  Non-list input is returned unchanged.

    Example transformation:
        arguments: '{"x": 1}' -> arguments: {"x": 1}
        arguments: "[1,2,3]"  -> arguments: {}
    """
    if not isinstance(messages, list):
        return messages

    def _clean_call(call):
        # Coerce a single tool call's function.arguments to a mapping.
        if not isinstance(call, dict):
            return call
        cleaned = dict(call)
        fn = cleaned.get("function")
        if isinstance(fn, dict):
            fn = dict(fn)
            fn["arguments"] = _coerce_arguments_mapping(fn.get("arguments", {}))
            cleaned["function"] = fn
        return cleaned

    sanitized = []
    for message in messages:
        if isinstance(message, dict):
            message = dict(message)
            tool_calls = message.get("tool_calls")
            if isinstance(tool_calls, list):
                message["tool_calls"] = [_clean_call(c) for c in tool_calls]
        sanitized.append(message)
    return sanitized


def _sanitize_tools_for_template(tools):
"""Ensure tools payload is template-safe and has mapping parameters.

Example transformations:
- {"function": {"name": "f", "parameters": "not-a-dict"}}
-> {"function": {"name": "f", "parameters": {"type": "object", "properties": {}}}}
- non-dict tool entries are dropped.
- non-list input returns None.
"""
if not isinstance(tools, list):
return None

sanitized = []
for tool in tools:
if not isinstance(tool, dict):
continue
tool_copy = dict(tool)
function = tool_copy.get("function")
if isinstance(function, dict):
function_copy = dict(function)
if not isinstance(function_copy.get("parameters"), dict):
function_copy["parameters"] = {"type": "object", "properties": {}}
tool_copy["function"] = function_copy
sanitized.append(tool_copy)
return sanitized

def _replace_prefix_tokens(
eos_token_id,
previous_turn_token_ids,
retokeenized_previous_turn_token_ids,
current_turn_token_ids
):
"""Replace the token ids that are associated with the previous turn with the actual tokens
from the previous generation (rather than the ones from the chat template application)."""

# Strip the EOS from the previous turn token ids if it exists
if previous_turn_token_ids[-1] == eos_token_id:
previous_turn_token_ids = previous_turn_token_ids[:-1]

# Find the last EOS token id in the previous turn token ids
last_eos_token_id_index = len(retokeenized_previous_turn_token_ids) - 1
for i in reversed(range(len(retokeenized_previous_turn_token_ids))):
if current_turn_token_ids[i] == eos_token_id:
last_eos_token_id_index = i
break

# Replace the current turn token ids with the tokens from the previous generation
current_turn_additional_token_ids = current_turn_token_ids[last_eos_token_id_index:]

# Return the previous turn token ids + the current turn token ids
return previous_turn_token_ids + current_turn_additional_token_ids

try:
import orjson

Expand All @@ -26,18 +174,31 @@

bp = Blueprint('chat_completions_api', __name__)

def apply_parsers(message_text, tools, parsers_list, tools_requested):
    """Runs CPU-intensive text parsing.

    Args:
        message_text: Raw generated text to parse.
        tools: Tool definitions forwarded to each parser.
        parsers_list: Parser names; each must exist in ``PARSER_MAPPING``.
        tools_requested: Whether the request actually asked for tools.  When
            False, incidental tool-call syntax is ignored: the text is left
            untouched and parsed ``tool_calls`` are discarded.

    Returns:
        Tuple of (possibly rewritten text, merged metadata dict).

    Raises:
        ValueError: If a parser name is unknown.
        AssertionError: If two parsers produce the same metadata key.
    """
    meta = {}
    for parser in parsers_list:
        if parser not in PARSER_MAPPING:
            raise ValueError(f"Parser {parser} not found in PARSER_MAPPING")

        prev_text = message_text
        parsed_text, new_info = PARSER_MAPPING[parser].parse(
            message_text, tools=tools
        )
        if "tool_calls" in new_info:
            new_info["tool_calls"] = _normalize_tool_calls(new_info.get("tool_calls", []))
            if not tools_requested:
                # Ignore incidental tool-call syntax in plain chat mode.
                parsed_text = prev_text
                new_info.pop("tool_calls", None)
        message_text = parsed_text

        overlapping_keys = meta.keys() & new_info.keys()
        assert not overlapping_keys, "Multiple parsers found the same information."
        meta.update(new_info)

    return message_text, meta

@bp.route('/chat/completions', methods=['POST'])
@bp.route('/v1/chat/completions', methods=['POST'])
Expand All @@ -48,42 +209,84 @@ async def chat_completions():
parsers = current_app.config['parsers']

req = await request.get_json()
tools = req.get("tools", None)
tools_requested = bool(tools)
messages = req.get("messages")
chat_template_kwargs = req.get("chat_template_kwargs", {})
if not isinstance(chat_template_kwargs, dict):
logger.warning("Ignoring non-dict chat_template_kwargs: %s", type(chat_template_kwargs).__name__)
chat_template_kwargs = {}\

# --- 1. Parse Messages ---
messages = req.get("messages")
if not messages:
return Response("Missing 'messages' field", status=400)
if not isinstance(messages, list):
return Response("'messages' must be a list", status=400)

# The OpenAI spec sends tool_call arguments as a JSON string, but
# Jinja chat templates iterate over them with |items, requiring a dict.
for msg in messages:
if msg.get("tool_calls"):
for tc in msg["tool_calls"]:
fn = tc.get("function", tc)
args = fn.get("arguments")
if isinstance(args, str):
try:
fn["arguments"] = json.loads(args)
except (json.JSONDecodeError, TypeError):
pass
template_messages = _sanitize_messages_for_template(messages)
template_tools = _sanitize_tools_for_template(tools)

try:
prompt_tokens = tokenizer.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
tools=req.get("tools", None),
**req.get("chat_template_kwargs", {}),
)
except (AttributeError, AssertionError):
warnings.warn(
"Tokenizer does not support 'apply_chat_template'. Using tokenize instead."
)
prompt_tokens = tokenizer.tokenize(
"\n".join([message["content"] for message in messages])
)
if hasattr(tokenizer, 'apply_chat_template'):
prompt_tokens = tokenizer.apply_chat_template(
template_messages,
tokenize=True,
add_generation_prompt=True,
tools=template_tools,
**chat_template_kwargs,
)

if req.get("prevent_retokenization", True):
# If we are avoiding retokenization, we need to replace some prompt tokens with the prompt/generation tokens from the previous generation
# This improves prefix cache hits and reduces logprob variation between training and inference.

eos_token_id = tokenizer.eos_id
assert eos_token_id is not None, "Your tokenizer must have an EOS token ID!"

warnings.warn(
"Avoiding prefix retokenization." \
" This is a patch that ensures subsequent generations are not retokenized differently than the previous generation." \
" This may cause unexpected behavior if messages (including system messages) are altered between generations."
)

# Find the last assistant message
last_assistant_message_idx = None
for i in reversed(range(len(template_messages))):
if template_messages[i]["role"] == "assistant":
last_assistant_message_idx = i
break

# If there was a previous assistant message, we need to replace the prefix tokens with the tokens from the previous generation
if last_assistant_message_idx is not None:
messages_to_last_assistant_message = template_messages[: last_assistant_message_idx + 1]

# Get the templated tokenization of just the previous generation
retokenized_previous_turn_token_ids = tokenizer.apply_chat_template(
messages_to_last_assistant_message,
tokenize=True,
add_generation_prompt=False,
tools=template_tools,
**chat_template_kwargs,
)

# Replace the prefix tokens with the tokens from the previous generation
last_assistant_message = template_messages[last_assistant_message_idx]
assert "prompt_token_ids" in last_assistant_message and "generation_token_ids" in last_assistant_message, \
"Last assistant message must have prompt_token_ids and generation_token_ids from previous generation to avoid prefix retokenization"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Can you split this into multiple lines?

previous_turn_token_ids = last_assistant_message["prompt_token_ids"] + last_assistant_message["generation_token_ids"]
prompt_tokens = _replace_prefix_tokens(
eos_token_id,
previous_turn_token_ids,
retokenized_previous_turn_token_ids,
prompt_tokens,
)

else:
warnings.warn(
"Tokenizer does not support 'apply_chat_template'. Using tokenize instead."
)
prompt_tokens = tokenizer.tokenize(
"\n".join([message["content"] for message in messages])
)
except Exception as e:
logger.error(f"{traceback.format_exc()}")
return Response(f"Error processing 'messages': {e}", status=500)
Expand Down Expand Up @@ -164,7 +367,14 @@ async def chat_completions():
k: v[1] if isinstance(v, (list, tuple)) and len(v) == 2 and v[0] == "tensor" else v
for k, v in result.items()
}
prompt_tokens_out = result["prompt_tokens"]

if result["status"] == "FAILED":
if result["sampling_params"]["num_tokens_to_generate"] <= 0:
return Response(f"Request {request_idx} failed due to context length overflow", status=400)
else:
return Response(f"Request {request_idx} failed due to internal error {result["events"]}", status=500)

prompt_tokens_out = result["prompt_tokens"] # The engine can modify prompt_tokens.
text_output = result["generated_text"]
prompt_tokens_count = len(prompt_tokens_out) if prompt_tokens_out is not None else 0
prompt_tokens_counts.append(prompt_tokens_count)
Expand Down Expand Up @@ -208,11 +418,11 @@ async def chat_completions():

if parsers:
message_text, metadata = apply_parsers(
message_text, req.get("tools", None), parsers
message_text, req.get("tools", None), parsers, tools_requested
)

message = {"role": "assistant", "content": message_text}
if "tool_calls" in metadata:
if metadata.get("tool_calls", []):
message["tool_calls"] = metadata["tool_calls"]
if "reasoning" in metadata:
message["reasoning"] = metadata["reasoning"]
Expand All @@ -223,20 +433,9 @@ async def chat_completions():
message["generation_log_probs"] = result.get("generated_log_probs", [])
return_log_probs = sampling_params.return_log_probs

gen_length = result.get("generated_length") or len(result.get("generated_tokens", []))
max_gen = result.get("sampling_params", {})
if isinstance(max_gen, dict):
max_gen = max_gen.get("num_tokens_to_generate", None)
elif hasattr(max_gen, "num_tokens_to_generate"):
max_gen = max_gen.num_tokens_to_generate
else:
max_gen = None
if metadata.get("tool_calls", []):
finish_reason = "tool_calls"
elif max_gen is not None and gen_length >= max_gen:
finish_reason = "tool_calls" if metadata.get("tool_calls", []) else "stop"
if len(result["generated_tokens"]) >= result["sampling_params"]["num_tokens_to_generate"]:
finish_reason = "length"
else:
finish_reason = "stop"

choice_data = {
"index": request_idx,
Expand All @@ -245,10 +444,10 @@ async def chat_completions():
"generation_token_ids": result["generated_tokens"],
"generation_log_probs": result.get("generated_log_probs", []),
"raw_text": result["prompt"] + result["generated_text"],
"logprobs": (
{"content": logprobs_content} if sampling_params.return_log_probs else None
),
"finish_reason": "tool_calls" if metadata.get("tool_calls", []) else finish_reason,
# 'logprobs' in chat API is an object containing 'content'
# "logprobs": {"content": logprobs_content} if logprobs_content else None,
Comment on lines +466 to +467
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove these lines?

"logprobs": {"content": logprobs_content} if return_log_probs else None,
"finish_reason": finish_reason,
}
choice_data["policy_staleness"] = result["policy_staleness"]
choice_data["kv_cache_staleness"] = result["kv_cache_staleness"]
Expand All @@ -266,6 +465,8 @@ async def chat_completions():
]

choices.append(choice_data)
if choice_data["generation_log_probs"] is None:
print(f"Generation log probs is None for request:\n{json.dumps(result, indent=4)}", flush=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we put this in a logging statement? Also does this indicate a bug?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Moved this to logging.

total_completion_tokens += len(result["generated_tokens"])
request_idx += 1

Expand Down
Loading
Loading