diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 2a41d9cf1cd..f18e0d339f5 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -866,7 +866,10 @@ def _add_request( self.failed_request_ids.append(request_id) if self.rank == 0: warnings.warn( - f"Request {request_id} failed to be added to the engine due to errors." + f"Request {request_id} failed to be added to the engine due to errors. " + f"Prompt Tokens: {len(request.prompt_tokens)} " + f"Tokens to generate: {request.sampling_params.num_tokens_to_generate} " + f"Max sequence length: {self.context.max_sequence_length} " ) return self.requests[request_id].future diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py index 532a2b9b5aa..ae26444bada 100644 --- a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py +++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/endpoints/chat_completions.py @@ -13,6 +13,159 @@ logger = logging.getLogger(__name__) +# pylint: disable=line-too-long + + +def _get_field(obj, key, default=None): + """Read a field from dict-like or object-like values.""" + if isinstance(obj, dict): + return obj.get(key, default) + return getattr(obj, key, default) + + +def _normalize_tool_calls(tool_calls): + """Normalize tool calls to OpenAI-compatible JSON primitives.""" + normalized = [] + for call in tool_calls or []: + fn = _get_field(call, "function", {}) or {} + fn_name = _get_field(fn, "name") + fn_args = _get_field(fn, "arguments", "") + if fn_name is None: + continue + if not isinstance(fn_args, str): + try: + fn_args = json.dumps(fn_args, ensure_ascii=False) + except TypeError: + fn_args = str(fn_args) + normalized.append( + { 
+ "id": str(_get_field(call, "id", f"call_{uuid.uuid4().hex[:24]}")), + "type": "function", + "function": {"name": str(fn_name), "arguments": fn_args}, + } + ) + return normalized + + +def _coerce_arguments_mapping(arguments): + """Coerce function.arguments to a mapping for HF/Jinja chat templates. + + Examples: + - {"x": 1} -> {"x": 1} + - '{"x": 1}' -> {"x": 1} + - "[1, 2]" -> {} # JSON parses, but not a mapping + - "not-json" -> {} + - None -> {} + """ + if isinstance(arguments, dict): + return arguments + if isinstance(arguments, str): + try: + parsed = json.loads(arguments) + except (TypeError, ValueError): + return {} + return parsed if isinstance(parsed, dict) else {} + return {} + + +def _sanitize_messages_for_template(messages): + """Prepare messages so tokenizer chat templates can safely consume them. + + This only normalizes tool-call argument payloads inside each message: + - messages[*].tool_calls[*].function.arguments is coerced to a dict. + + Example transformation: + Input: + [{"role": "assistant", "tool_calls": [{"function": {"name": "f", "arguments": "{\"x\": 1}"}}]}] + Output: + [{"role": "assistant", "tool_calls": [{"function": {"name": "f", "arguments": {"x": 1}}}]}] + + Another example: + - arguments: "[1,2,3]" -> arguments: {} + """ + if not isinstance(messages, list): + return messages + sanitized = [] + for message in messages: + if not isinstance(message, dict): + sanitized.append(message) + continue + msg_copy = dict(message) + tool_calls = msg_copy.get("tool_calls") + if isinstance(tool_calls, list): + sanitized_tool_calls = [] + for call in tool_calls: + if not isinstance(call, dict): + sanitized_tool_calls.append(call) + continue + call_copy = dict(call) + function = call_copy.get("function") + if isinstance(function, dict): + function_copy = dict(function) + function_copy["arguments"] = _coerce_arguments_mapping( + function_copy.get("arguments", {}) + ) + call_copy["function"] = function_copy + sanitized_tool_calls.append(call_copy) + 
msg_copy["tool_calls"] = sanitized_tool_calls + sanitized.append(msg_copy) + return sanitized + + +def _sanitize_tools_for_template(tools): + """Ensure tools payload is template-safe and has mapping parameters. + + Example transformations: + - {"function": {"name": "f", "parameters": "not-a-dict"}} + -> {"function": {"name": "f", "parameters": {"type": "object", "properties": {}}}} + - non-dict tool entries are dropped. + - non-list input returns None. + """ + if not isinstance(tools, list): + return None + + sanitized = [] + for tool in tools: + if not isinstance(tool, dict): + continue + tool_copy = dict(tool) + function = tool_copy.get("function") + if isinstance(function, dict): + function_copy = dict(function) + if not isinstance(function_copy.get("parameters"), dict): + function_copy["parameters"] = {"type": "object", "properties": {}} + tool_copy["function"] = function_copy + sanitized.append(tool_copy) + return sanitized + + +def _replace_prefix_tokens( + eos_token_id, + previous_turn_token_ids, + retokenized_previous_turn_token_ids, + current_turn_token_ids, +): + """Replace the token ids that are associated with the previous turn with the actual tokens + from the previous generation (rather than the ones from the chat template application).""" + + # Strip the EOS from the previous turn token ids if it exists + if previous_turn_token_ids[-1] == eos_token_id: + previous_turn_token_ids = previous_turn_token_ids[:-1] + + # Find the last EOS token id within the previous-turn span of the current turn token ids + last_eos_token_id_index = len(retokenized_previous_turn_token_ids) - 1 + for i in reversed(range(len(retokenized_previous_turn_token_ids))): + if current_turn_token_ids[i] == eos_token_id: + last_eos_token_id_index = i + break + + # Replace the current turn token ids with the tokens from the previous generation + current_turn_additional_token_ids = current_turn_token_ids[last_eos_token_id_index:] + + # Return the previous turn token ids + the current turn token ids + return 
previous_turn_token_ids + current_turn_additional_token_ids + + try: import orjson @@ -26,18 +179,29 @@ bp = Blueprint('chat_completions_api', __name__) - def apply_parsers(text, tools, parsers_list): + def apply_parsers(message_text, tools, parsers_list, tools_requested): """Runs CPU-intensive text parsing.""" meta = {} for parser in parsers_list: if parser not in PARSER_MAPPING: raise ValueError(f"Parser {parser} not found in PARSER_MAPPING") - text, new_info = PARSER_MAPPING[parser].parse(text, tools=tools) + + prev_text = message_text + parsed_text, new_info = PARSER_MAPPING[parser].parse(message_text, tools=tools) + if "tool_calls" in new_info: + new_info["tool_calls"] = _normalize_tool_calls(new_info.get("tool_calls", [])) + if not tools_requested: + # Ignore incidental tool-call syntax in plain chat mode. + parsed_text = prev_text + new_info.pop("tool_calls", None) + message_text = parsed_text + assert not ( meta.keys() & new_info.keys() ), "Multiple parsers found the same information." meta.update(new_info) - return text, meta + + return message_text, meta @bp.route('/chat/completions', methods=['POST']) @bp.route('/v1/chat/completions', methods=['POST']) @@ -48,42 +212,92 @@ async def chat_completions(): parsers = current_app.config['parsers'] req = await request.get_json() - - # --- 1. Parse Messages --- + tools = req.get("tools", None) + tools_requested = bool(tools) messages = req.get("messages") + chat_template_kwargs = req.get("chat_template_kwargs", {}) + if not isinstance(chat_template_kwargs, dict): + logger.warning( + "Ignoring non-dict chat_template_kwargs: %s", type(chat_template_kwargs).__name__ + ) + chat_template_kwargs = {} + # --- 1. 
Parse Messages --- if not messages: return Response("Missing 'messages' field", status=400) if not isinstance(messages, list): return Response("'messages' must be a list", status=400) - - # The OpenAI spec sends tool_call arguments as a JSON string, but - # Jinja chat templates iterate over them with |items, requiring a dict. - for msg in messages: - if msg.get("tool_calls"): - for tc in msg["tool_calls"]: - fn = tc.get("function", tc) - args = fn.get("arguments") - if isinstance(args, str): - try: - fn["arguments"] = json.loads(args) - except (json.JSONDecodeError, TypeError): - pass + template_messages = _sanitize_messages_for_template(messages) + template_tools = _sanitize_tools_for_template(tools) try: - prompt_tokens = tokenizer.apply_chat_template( - messages, - tokenize=True, - add_generation_prompt=True, - tools=req.get("tools", None), - **req.get("chat_template_kwargs", {}), - ) - except (AttributeError, AssertionError): - warnings.warn( - "Tokenizer does not support 'apply_chat_template'. Using tokenize instead." - ) - prompt_tokens = tokenizer.tokenize( - "\n".join([message["content"] for message in messages]) - ) + if hasattr(tokenizer, 'apply_chat_template'): + prompt_tokens = tokenizer.apply_chat_template( + template_messages, + tokenize=True, + add_generation_prompt=True, + tools=template_tools, + **chat_template_kwargs, + ) + + if req.get("prevent_retokenization", True): + # If we are avoiding retokenization, we need to replace some prompt tokens with the prompt/generation tokens from the previous generation + # This improves prefix cache hits and reduces logprob variation between training and inference. + + eos_token_id = tokenizer.eos_id + assert eos_token_id is not None, "Your tokenizer must have an EOS token ID!" + + warnings.warn( + "Avoiding prefix retokenization." + " This is a patch that ensures subsequent generations are not retokenized differently than the previous generation." 
+ " This may cause unexpected behavior if messages (including system messages) are altered between generations." + ) + + # Find the last assistant message + last_assistant_message_idx = None + for i in reversed(range(len(template_messages))): + if template_messages[i]["role"] == "assistant": + last_assistant_message_idx = i + break + + # If there was a previous assistant message, we need to replace the prefix tokens with the tokens from the previous generation + if last_assistant_message_idx is not None: + messages_to_last_assistant_message = template_messages[ + : last_assistant_message_idx + 1 + ] + + # Get the templated tokenization of just the previous generation + retokenized_previous_turn_token_ids = tokenizer.apply_chat_template( + messages_to_last_assistant_message, + tokenize=True, + add_generation_prompt=False, + tools=template_tools, + **chat_template_kwargs, + ) + + # Replace the prefix tokens with the tokens from the previous generation + last_assistant_message = template_messages[last_assistant_message_idx] + assert ( + "prompt_token_ids" in last_assistant_message + and "generation_token_ids" in last_assistant_message + ), "Last assistant message must have prompt_token_ids and generation_token_ids from previous generation to avoid prefix retokenization" + previous_turn_token_ids = ( + last_assistant_message["prompt_token_ids"] + + last_assistant_message["generation_token_ids"] + ) + prompt_tokens = _replace_prefix_tokens( + eos_token_id, + previous_turn_token_ids, + retokenized_previous_turn_token_ids, + prompt_tokens, + ) + + else: + warnings.warn( + "Tokenizer does not support 'apply_chat_template'. Using tokenize instead." 
+ ) + prompt_tokens = tokenizer.tokenize( + "\n".join([message["content"] for message in messages]) + ) except Exception as e: logger.error(f"{traceback.format_exc()}") return Response(f"Error processing 'messages': {e}", status=500) @@ -164,7 +378,19 @@ async def chat_completions(): k: v[1] if isinstance(v, (list, tuple)) and len(v) == 2 and v[0] == "tensor" else v for k, v in result.items() } - prompt_tokens_out = result["prompt_tokens"] + + if result["status"] == "FAILED": + if result["sampling_params"]["num_tokens_to_generate"] <= 0: + return Response( + f"Request {request_idx} failed due to context length overflow", status=400 + ) + else: + return Response( + f"Request {request_idx} failed due to internal error {result['events']}", + status=500, + ) + + prompt_tokens_out = result["prompt_tokens"] # The engine can modify prompt_tokens. text_output = result["generated_text"] prompt_tokens_count = len(prompt_tokens_out) if prompt_tokens_out is not None else 0 prompt_tokens_counts.append(prompt_tokens_count) @@ -208,11 +434,11 @@ async def chat_completions(): if parsers: message_text, metadata = apply_parsers( - message_text, req.get("tools", None), parsers + message_text, req.get("tools", None), parsers, tools_requested ) message = {"role": "assistant", "content": message_text} - if "tool_calls" in metadata: + if metadata.get("tool_calls", []): message["tool_calls"] = metadata["tool_calls"] if "reasoning" in metadata: message["reasoning"] = metadata["reasoning"] @@ -223,20 +449,12 @@ async def chat_completions(): message["generation_log_probs"] = result.get("generated_log_probs", []) return_log_probs = sampling_params.return_log_probs - gen_length = result.get("generated_length") or len(result.get("generated_tokens", [])) - max_gen = result.get("sampling_params", {}) - if isinstance(max_gen, dict): - max_gen = max_gen.get("num_tokens_to_generate", None) - elif hasattr(max_gen, "num_tokens_to_generate"): - max_gen = max_gen.num_tokens_to_generate - else: - max_gen 
= None - if metadata.get("tool_calls", []): - finish_reason = "tool_calls" - elif max_gen is not None and gen_length >= max_gen: + finish_reason = "tool_calls" if metadata.get("tool_calls", []) else "stop" + if ( + len(result["generated_tokens"]) + >= result["sampling_params"]["num_tokens_to_generate"] + ): finish_reason = "length" - else: - finish_reason = "stop" choice_data = { "index": request_idx, @@ -245,10 +463,10 @@ async def chat_completions(): "generation_token_ids": result["generated_tokens"], "generation_log_probs": result.get("generated_log_probs", []), "raw_text": result["prompt"] + result["generated_text"], - "logprobs": ( - {"content": logprobs_content} if sampling_params.return_log_probs else None - ), - "finish_reason": "tool_calls" if metadata.get("tool_calls", []) else finish_reason, + # 'logprobs' in chat API is an object containing 'content' + # "logprobs": {"content": logprobs_content} if logprobs_content else None, + "logprobs": {"content": logprobs_content} if return_log_probs else None, + "finish_reason": finish_reason, } choice_data["policy_staleness"] = result["policy_staleness"] choice_data["kv_cache_staleness"] = result["kv_cache_staleness"] @@ -266,6 +484,10 @@ async def chat_completions(): ] choices.append(choice_data) + if choice_data["generation_log_probs"] is None: + logger.warning( + "Generation log probs is None for request:\n%s", json.dumps(result, indent=4) + ) total_completion_tokens += len(result["generated_tokens"]) request_idx += 1 diff --git a/megatron/core/ssm/gated_delta_net.py b/megatron/core/ssm/gated_delta_net.py index 7b1149a781d..e23599689bb 100644 --- a/megatron/core/ssm/gated_delta_net.py +++ b/megatron/core/ssm/gated_delta_net.py @@ -174,8 +174,10 @@ def __init__( dtype=config.params_dtype, ) setattr(self.conv1d.weight, "tensor_model_parallel", True) + setattr(self.conv1d.weight, "partition_dim", 0) if conv_bias: setattr(self.conv1d.bias, "tensor_model_parallel", True) + setattr(self.conv1d.bias, "partition_dim", 
0) # Time step projection (discretization) self.num_v_heads_local_tp = self.num_value_heads // self.tp_size @@ -188,6 +190,7 @@ def __init__( ) ) setattr(self.dt_bias, "tensor_model_parallel", True) + setattr(self.dt_bias, "partition_dim", 0) # A_log parameter self.A_log = nn.Parameter( torch.empty( @@ -197,6 +200,7 @@ def __init__( ) ) setattr(self.A_log, "tensor_model_parallel", True) + setattr(self.A_log, "partition_dim", 0) # Output layernorm before projection self.out_norm = build_module( diff --git a/megatron/core/ssm/mamba_mixer.py b/megatron/core/ssm/mamba_mixer.py index 6c2395ded94..d387802dea3 100644 --- a/megatron/core/ssm/mamba_mixer.py +++ b/megatron/core/ssm/mamba_mixer.py @@ -273,7 +273,9 @@ def __init__( dtype=config.params_dtype, ) setattr(self.conv1d.weight, "tensor_model_parallel", True) + setattr(self.conv1d.weight, "partition_dim", 0) setattr(self.conv1d.bias, "tensor_model_parallel", True) + setattr(self.conv1d.bias, "partition_dim", 0) if self.config.perform_initialization: if self.conv_init is not None: nn.init.uniform_(self.conv1d.weight, -self.conv_init, self.conv_init) @@ -298,6 +300,7 @@ def __init__( inv_dt = dt + torch.log(-torch.expm1(-dt)) self.dt_bias = nn.Parameter(inv_dt) setattr(self.dt_bias, "tensor_model_parallel", True) + setattr(self.dt_bias, "partition_dim", 0) # A parameter assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0] @@ -309,6 +312,7 @@ def __init__( A_log = torch.log(A) # Keep A_log in fp32 self.A_log = nn.Parameter(A_log) setattr(self.A_log, "tensor_model_parallel", True) + setattr(self.A_log, "partition_dim", 0) # D "skip" parameter self.D = nn.Parameter( @@ -318,6 +322,7 @@ def __init__( ) ) # Keep in fp32 setattr(self.D, "tensor_model_parallel", True) + setattr(self.D, "partition_dim", 0) if self.rmsnorm: assert RMSNormGated is not None @@ -330,6 +335,7 @@ def __init__( dtype=config.params_dtype, ) setattr(self.norm.weight, "tensor_model_parallel", True) + setattr(self.norm.weight, 
"partition_dim", 0) # Assume sequence parallelism: input is partitioned along d_inner and # output is partitioned along the sequence dimension