-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Shanmugamr1992/megatron inference ultra #3784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 55 commits
ff01c9a
b3e59f0
f8c8025
fc9d093
8d6eeee
9d0ff35
bb961a3
a87ea0b
d0c19a9
3fb373c
291b820
683312f
526ccb2
178680c
8c426d7
2e2c8d5
e8d55b9
68e3c6f
5a3e1bb
611de48
b3605b8
45b7436
8081aff
a7f2dab
aac10bb
690ab38
cee4c28
97d3343
136cfdc
dbc7117
9ce666b
6d3afc0
45ffc2b
3eb054f
f8639d0
9a17764
bdfe52c
13cf3ac
6e93217
ee4f148
ac01210
e91e0e2
d025a9c
c3a315e
78af3ca
9de851d
742327c
1e44948
ccdf4bc
c38ff24
64aee21
3e9b488
a603cf5
bbefd47
f71f013
64cf04b
d43aeca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -9,10 +9,158 @@ | |
| import warnings | ||
|
|
||
| from megatron.core.inference.sampling_params import SamplingParams | ||
| from megatron.core.inference.inference_request import DynamicInferenceRequest | ||
| from megatron.core.tokenizers.text.parsers import PARSER_MAPPING | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
def _get_field(obj, key, default=None):
    """Fetch ``key`` from either a dict or an attribute-bearing object.

    Args:
        obj: A dict (queried with ``dict.get``) or any other object
            (queried with ``getattr``).
        key: Dictionary key or attribute name to look up.
        default: Value returned when the key/attribute is absent.

    Returns:
        The stored value, or ``default`` when it is missing.
    """
    if not isinstance(obj, dict):
        return getattr(obj, key, default)
    return obj.get(key, default)
|
|
||
|
|
||
def _normalize_tool_calls(tool_calls):
    """Normalize tool calls to OpenAI-compatible JSON primitives.

    Args:
        tool_calls: Iterable of tool calls, each either dict-like or
            attribute-style (both are read through ``_get_field``).
            ``None`` or an empty iterable yields an empty list.

    Returns:
        A list of dicts shaped as
        ``{"id": str, "type": "function",
        "function": {"name": str, "arguments": str}}``.
        Calls whose function has no ``name`` are dropped.
    """
    normalized = []
    for call in tool_calls or []:
        fn = _get_field(call, "function", {}) or {}
        fn_name = _get_field(fn, "name")
        if fn_name is None:
            # A call without a function name cannot be dispatched; skip it.
            continue

        fn_args = _get_field(fn, "arguments", "")
        if not isinstance(fn_args, str):
            try:
                fn_args = json.dumps(fn_args, ensure_ascii=False)
            except (TypeError, ValueError):
                # TypeError: non-serializable object; ValueError: circular
                # reference. Fall back to a plain string representation.
                fn_args = str(fn_args)

        # Generate an id only when the parser did not supply a usable one.
        # A present-but-None (or empty) id must not become the string "None".
        call_id = _get_field(call, "id")
        if call_id is None or call_id == "":
            call_id = f"call_{uuid.uuid4().hex[:24]}"

        normalized.append(
            {
                "id": str(call_id),
                "type": "function",
                "function": {"name": str(fn_name), "arguments": fn_args},
            }
        )
    return normalized
|
|
||
|
|
||
def _coerce_arguments_mapping(arguments):
    """Return ``function.arguments`` as a dict for HF/Jinja chat templates.

    A dict is returned as-is (same object). A string is JSON-decoded and
    kept only when the decoded value is a dict. Everything else collapses
    to an empty dict.

    Examples:
        - {"x": 1}     -> {"x": 1}
        - '{"x": 1}'   -> {"x": 1}
        - "[1, 2]"     -> {}   # valid JSON, but not a mapping
        - "not-json"   -> {}
        - None         -> {}
    """
    if isinstance(arguments, dict):
        return arguments
    if not isinstance(arguments, str):
        return {}

    try:
        decoded = json.loads(arguments)
    except (TypeError, ValueError):
        return {}

    if isinstance(decoded, dict):
        return decoded
    return {}
|
|
||
|
|
||
def _sanitize_messages_for_template(messages):
    """Return a copy of ``messages`` that chat templates can consume safely.

    The only normalization performed is on tool-call argument payloads:
    ``messages[*].tool_calls[*].function.arguments`` is coerced to a dict
    via ``_coerce_arguments_mapping``. Dict messages, dict tool calls and
    dict ``function`` entries are shallow-copied, so the caller's objects
    are not mutated. Non-dict entries are passed through unchanged, and a
    non-list ``messages`` value is returned as-is.

    Example:
        arguments given as the JSON string '{"x": 1}' becomes {"x": 1};
        arguments given as "[1,2,3]" (not a mapping) becomes {}.
    """
    if not isinstance(messages, list):
        return messages

    def _clean_call(call):
        # Only dict-shaped calls carry a function payload we can rewrite.
        if not isinstance(call, dict):
            return call
        cleaned = dict(call)
        fn = cleaned.get("function")
        if isinstance(fn, dict):
            fn = dict(fn)
            fn["arguments"] = _coerce_arguments_mapping(fn.get("arguments", {}))
            cleaned["function"] = fn
        return cleaned

    result = []
    for entry in messages:
        if isinstance(entry, dict):
            entry = dict(entry)
            calls = entry.get("tool_calls")
            if isinstance(calls, list):
                entry["tool_calls"] = [_clean_call(c) for c in calls]
        result.append(entry)
    return result
|
|
||
|
|
||
def _sanitize_tools_for_template(tools):
    """Return a template-safe copy of ``tools`` with mapping ``parameters``.

    Behavior:
        - A non-list ``tools`` value yields ``None``.
        - Non-dict tool entries are dropped.
        - For each tool whose ``function`` is a dict, a ``parameters`` value
          that is missing or not a dict is replaced with
          ``{"type": "object", "properties": {}}``.

    Tool dicts and their ``function`` dicts are shallow-copied, so the
    caller's objects are not mutated.
    """
    if not isinstance(tools, list):
        return None

    cleaned = []
    for entry in tools:
        if not isinstance(entry, dict):
            continue
        entry = dict(entry)
        fn = entry.get("function")
        if isinstance(fn, dict):
            fn = dict(fn)
            params = fn.get("parameters")
            if not isinstance(params, dict):
                fn["parameters"] = {"type": "object", "properties": {}}
            entry["function"] = fn
        cleaned.append(entry)
    return cleaned
|
|
||
def _replace_prefix_tokens(
    eos_token_id,
    previous_turn_token_ids,
    retokeenized_previous_turn_token_ids,
    current_turn_token_ids,
):
    """Splice the previous generation's actual tokens into the current prompt.

    The leading part of ``current_turn_token_ids`` that corresponds to the
    previous turn is replaced with ``previous_turn_token_ids`` (the tokens
    really produced last time) instead of the ones obtained by re-applying
    the chat template.

    Args:
        eos_token_id: Token id marking end-of-sequence.
        previous_turn_token_ids: Prompt + generation token ids of the
            previous turn. A single trailing EOS, if present, is stripped.
        retokeenized_previous_turn_token_ids: Chat-template tokenization of
            the conversation up to the last assistant message. Only its
            length is used, as the upper bound of the EOS search. (The
            misspelled name is kept so keyword callers keep working.)
        current_turn_token_ids: Chat-template tokenization of the full
            current conversation.

    Returns:
        ``previous_turn_token_ids`` (without its trailing EOS) followed by
        ``current_turn_token_ids`` starting at the last EOS found within the
        retokenized prefix. If no EOS is found there, the splice starts at
        the last position of the retokenized prefix.
    """
    # Strip the EOS from the previous turn if present; an empty previous
    # turn is tolerated instead of raising IndexError.
    if len(previous_turn_token_ids) > 0 and previous_turn_token_ids[-1] == eos_token_id:
        previous_turn_token_ids = previous_turn_token_ids[:-1]

    retokenized_len = len(retokeenized_previous_turn_token_ids)

    # Default splice point: last position of the retokenized prefix. Clamp to
    # 0 so an empty prefix keeps the whole current sequence rather than
    # accidentally slicing with -1.
    splice_index = max(retokenized_len - 1, 0)

    # Search backwards for the last EOS inside the prefix region, never
    # indexing past the end of the current sequence.
    search_end = min(retokenized_len, len(current_turn_token_ids))
    for i in range(search_end - 1, -1, -1):
        if current_turn_token_ids[i] == eos_token_id:
            splice_index = i
            break

    # Everything from the splice point on is new relative to the previous turn.
    current_turn_additional_token_ids = current_turn_token_ids[splice_index:]

    return previous_turn_token_ids + current_turn_additional_token_ids
|
|
||
| try: | ||
| import orjson | ||
|
|
||
|
|
@@ -26,18 +174,31 @@ | |
|
|
||
| bp = Blueprint('chat_completions_api', __name__) | ||
|
|
||
| def apply_parsers(text, tools, parsers_list): | ||
| def apply_parsers(message_text, tools, parsers_list, tools_requested): | ||
| """Runs CPU-intensive text parsing.""" | ||
| meta = {} | ||
| for parser in parsers_list: | ||
| if parser not in PARSER_MAPPING: | ||
| raise ValueError(f"Parser {parser} not found in PARSER_MAPPING") | ||
| text, new_info = PARSER_MAPPING[parser].parse(text, tools=tools) | ||
|
|
||
| prev_text = message_text | ||
| parsed_text, new_info = PARSER_MAPPING[parser].parse( | ||
| message_text, tools=tools | ||
| ) | ||
| if "tool_calls" in new_info: | ||
| new_info["tool_calls"] = _normalize_tool_calls(new_info.get("tool_calls", [])) | ||
| if not tools_requested: | ||
| # Ignore incidental tool-call syntax in plain chat mode. | ||
| parsed_text = prev_text | ||
| new_info.pop("tool_calls", None) | ||
| message_text = parsed_text | ||
|
|
||
| assert not ( | ||
| meta.keys() & new_info.keys() | ||
| ), "Multiple parsers found the same information." | ||
| meta.update(new_info) | ||
| return text, meta | ||
|
|
||
| return message_text, meta | ||
|
|
||
| @bp.route('/chat/completions', methods=['POST']) | ||
| @bp.route('/v1/chat/completions', methods=['POST']) | ||
|
|
@@ -48,42 +209,84 @@ async def chat_completions(): | |
| parsers = current_app.config['parsers'] | ||
|
|
||
| req = await request.get_json() | ||
| tools = req.get("tools", None) | ||
| tools_requested = bool(tools) | ||
| messages = req.get("messages") | ||
| chat_template_kwargs = req.get("chat_template_kwargs", {}) | ||
| if not isinstance(chat_template_kwargs, dict): | ||
| logger.warning("Ignoring non-dict chat_template_kwargs: %s", type(chat_template_kwargs).__name__) | ||
| chat_template_kwargs = {}\ | ||
|
|
||
| # --- 1. Parse Messages --- | ||
| messages = req.get("messages") | ||
| if not messages: | ||
| return Response("Missing 'messages' field", status=400) | ||
| if not isinstance(messages, list): | ||
| return Response("'messages' must be a list", status=400) | ||
|
|
||
| # The OpenAI spec sends tool_call arguments as a JSON string, but | ||
| # Jinja chat templates iterate over them with |items, requiring a dict. | ||
| for msg in messages: | ||
| if msg.get("tool_calls"): | ||
| for tc in msg["tool_calls"]: | ||
| fn = tc.get("function", tc) | ||
| args = fn.get("arguments") | ||
| if isinstance(args, str): | ||
| try: | ||
| fn["arguments"] = json.loads(args) | ||
| except (json.JSONDecodeError, TypeError): | ||
| pass | ||
| template_messages = _sanitize_messages_for_template(messages) | ||
| template_tools = _sanitize_tools_for_template(tools) | ||
|
|
||
| try: | ||
| prompt_tokens = tokenizer.apply_chat_template( | ||
| messages, | ||
| tokenize=True, | ||
| add_generation_prompt=True, | ||
| tools=req.get("tools", None), | ||
| **req.get("chat_template_kwargs", {}), | ||
| ) | ||
| except (AttributeError, AssertionError): | ||
| warnings.warn( | ||
| "Tokenizer does not support 'apply_chat_template'. Using tokenize instead." | ||
| ) | ||
| prompt_tokens = tokenizer.tokenize( | ||
| "\n".join([message["content"] for message in messages]) | ||
| ) | ||
| if hasattr(tokenizer, 'apply_chat_template'): | ||
| prompt_tokens = tokenizer.apply_chat_template( | ||
| template_messages, | ||
| tokenize=True, | ||
| add_generation_prompt=True, | ||
| tools=template_tools, | ||
| **chat_template_kwargs, | ||
| ) | ||
|
|
||
| if req.get("prevent_retokenization", True): | ||
| # If we are avoiding retokenization, we need to replace some prompt tokens with the prompt/generation tokens from the previous generation | ||
| # This improves prefix cache hits and reduces logprob variation between training and inference. | ||
|
|
||
| eos_token_id = tokenizer.eos_id | ||
| assert eos_token_id is not None, "Your tokenizer must have an EOS token ID!" | ||
|
|
||
| warnings.warn( | ||
| "Avoiding prefix retokenization." \ | ||
| " This is a patch that ensures subsequent generations are not retokenized differently than the previous generation." \ | ||
| " This may cause unexpected behavior if messages (including system messages) are altered between generations." | ||
| ) | ||
|
|
||
| # Find the last assistant message | ||
| last_assistant_message_idx = None | ||
| for i in reversed(range(len(template_messages))): | ||
| if template_messages[i]["role"] == "assistant": | ||
| last_assistant_message_idx = i | ||
| break | ||
|
|
||
| # If there was a previous assistant message, we need to replace the prefix tokens with the tokens from the previous generation | ||
| if last_assistant_message_idx is not None: | ||
| messages_to_last_assistant_message = template_messages[: last_assistant_message_idx + 1] | ||
|
|
||
| # Get the templated tokenization of just the previous generation | ||
| retokenized_previous_turn_token_ids = tokenizer.apply_chat_template( | ||
| messages_to_last_assistant_message, | ||
| tokenize=True, | ||
| add_generation_prompt=False, | ||
| tools=template_tools, | ||
| **chat_template_kwargs, | ||
| ) | ||
|
|
||
| # Replace the prefix tokens with the tokens from the previous generation | ||
| last_assistant_message = template_messages[last_assistant_message_idx] | ||
| assert "prompt_token_ids" in last_assistant_message and "generation_token_ids" in last_assistant_message, \ | ||
| "Last assistant message must have prompt_token_ids and generation_token_ids from previous generation to avoid prefix retokenization" | ||
| previous_turn_token_ids = last_assistant_message["prompt_token_ids"] + last_assistant_message["generation_token_ids"] | ||
| prompt_tokens = _replace_prefix_tokens( | ||
| eos_token_id, | ||
| previous_turn_token_ids, | ||
| retokenized_previous_turn_token_ids, | ||
| prompt_tokens, | ||
| ) | ||
|
|
||
| else: | ||
| warnings.warn( | ||
| "Tokenizer does not support 'apply_chat_template'. Using tokenize instead." | ||
| ) | ||
| prompt_tokens = tokenizer.tokenize( | ||
| "\n".join([message["content"] for message in messages]) | ||
| ) | ||
| except Exception as e: | ||
| logger.error(f"{traceback.format_exc()}") | ||
| return Response(f"Error processing 'messages': {e}", status=500) | ||
|
|
@@ -164,7 +367,14 @@ async def chat_completions(): | |
| k: v[1] if isinstance(v, (list, tuple)) and len(v) == 2 and v[0] == "tensor" else v | ||
| for k, v in result.items() | ||
| } | ||
| prompt_tokens_out = result["prompt_tokens"] | ||
|
|
||
| if result["status"] == "FAILED": | ||
| if result["sampling_params"]["num_tokens_to_generate"] <= 0: | ||
| return Response(f"Request {request_idx} failed due to context length overflow", status=400) | ||
| else: | ||
| return Response(f"Request {request_idx} failed due to internal error {result["events"]}", status=500) | ||
|
|
||
| prompt_tokens_out = result["prompt_tokens"] # The engine can modify prompt_tokens. | ||
| text_output = result["generated_text"] | ||
| prompt_tokens_count = len(prompt_tokens_out) if prompt_tokens_out is not None else 0 | ||
| prompt_tokens_counts.append(prompt_tokens_count) | ||
|
|
@@ -208,11 +418,11 @@ async def chat_completions(): | |
|
|
||
| if parsers: | ||
| message_text, metadata = apply_parsers( | ||
| message_text, req.get("tools", None), parsers | ||
| message_text, req.get("tools", None), parsers, tools_requested | ||
| ) | ||
|
|
||
| message = {"role": "assistant", "content": message_text} | ||
| if "tool_calls" in metadata: | ||
| if metadata.get("tool_calls", []): | ||
| message["tool_calls"] = metadata["tool_calls"] | ||
| if "reasoning" in metadata: | ||
| message["reasoning"] = metadata["reasoning"] | ||
|
|
@@ -223,20 +433,9 @@ async def chat_completions(): | |
| message["generation_log_probs"] = result.get("generated_log_probs", []) | ||
| return_log_probs = sampling_params.return_log_probs | ||
|
|
||
| gen_length = result.get("generated_length") or len(result.get("generated_tokens", [])) | ||
| max_gen = result.get("sampling_params", {}) | ||
| if isinstance(max_gen, dict): | ||
| max_gen = max_gen.get("num_tokens_to_generate", None) | ||
| elif hasattr(max_gen, "num_tokens_to_generate"): | ||
| max_gen = max_gen.num_tokens_to_generate | ||
| else: | ||
| max_gen = None | ||
| if metadata.get("tool_calls", []): | ||
| finish_reason = "tool_calls" | ||
| elif max_gen is not None and gen_length >= max_gen: | ||
| finish_reason = "tool_calls" if metadata.get("tool_calls", []) else "stop" | ||
| if len(result["generated_tokens"]) >= result["sampling_params"]["num_tokens_to_generate"]: | ||
| finish_reason = "length" | ||
| else: | ||
| finish_reason = "stop" | ||
|
|
||
| choice_data = { | ||
| "index": request_idx, | ||
|
|
@@ -245,10 +444,10 @@ async def chat_completions(): | |
| "generation_token_ids": result["generated_tokens"], | ||
| "generation_log_probs": result.get("generated_log_probs", []), | ||
| "raw_text": result["prompt"] + result["generated_text"], | ||
| "logprobs": ( | ||
| {"content": logprobs_content} if sampling_params.return_log_probs else None | ||
| ), | ||
| "finish_reason": "tool_calls" if metadata.get("tool_calls", []) else finish_reason, | ||
| # 'logprobs' in chat API is an object containing 'content' | ||
| # "logprobs": {"content": logprobs_content} if logprobs_content else None, | ||
|
Comment on lines
+466
to
+467
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you remove these lines? |
||
| "logprobs": {"content": logprobs_content} if return_log_probs else None, | ||
| "finish_reason": finish_reason, | ||
| } | ||
| choice_data["policy_staleness"] = result["policy_staleness"] | ||
| choice_data["kv_cache_staleness"] = result["kv_cache_staleness"] | ||
|
|
@@ -266,6 +465,8 @@ async def chat_completions(): | |
| ] | ||
|
|
||
| choices.append(choice_data) | ||
| if choice_data["generation_log_probs"] is None: | ||
| print(f"Generation log probs is None for request:\n{json.dumps(result, indent=4)}", flush=True) | ||
|
||
| total_completion_tokens += len(result["generated_tokens"]) | ||
| request_idx += 1 | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Can you split this into multiple lines?