22 changes: 11 additions & 11 deletions docs/en/agentic/oai_endpoint.md
@@ -46,16 +46,14 @@ Minimal request shape:
"messages": [
{"role": "system", "content": "You are a concise assistant."},
{"role": "user", "content": "Answer with one word: 2+2?"}
],
"logprobs": true,
"return_prompt_token_ids": true
]
}
```
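
For instance, a client call against this endpoint might look like the sketch below. The URL, port, and model name are placeholders rather than values fixed by Miles; adjust them to match how the server was launched.

```python
# Hypothetical client sketch; URL, port, and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",  # assumed local endpoint
    json={
        "model": "my-model",  # must match the served model
        "messages": [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Answer with one word: 2+2?"},
        ],
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```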

You can pass any OpenAI-compatible parameters in the payload, or any
SGLang-compatible `ChatCompletionRequest` parameters. Note:
`logprobs=True` and `return_prompt_token_ids=True` are set in
`request_kwargs` to extract token ids and logprobs for TITO (see below).
SGLang-compatible `ChatCompletionRequest` parameters. Note: with
`--use-session-server`, the session middleware sets the token-tracking fields
TITO needs and injects Miles-owned `input_ids` before proxying to SGLang.
Do **not** set `logprob_start_len=0` — it would disable SGLang's prefix
cache.

@@ -129,11 +127,13 @@ request. It requires `--use-session-server`; with that on, miles handles
three things on your behalf:

- Hardcodes the SGLang flags TITO needs on every chat request
(`logprobs=True`, `return_meta_info=True`, `return_prompt_token_ids=True`,
`no_stop_trim=False`); these are set by the middleware in
`miles/rollout/session/sessions.py` and override any agent-passed values.
- Reuses the token prefix from previous turns by injecting `input_ids` on
follow-up requests.
(`logprobs=True`, `return_meta_info=True`, `no_stop_trim=False`); these are
set by the middleware in `miles/rollout/session/sessions.py` and override any
agent-passed values (see the sketch after this list).
- Reuses the token prefix from previous turns by injecting Miles-owned
`input_ids` on every proxied chat request. The response
`choice.prompt_token_ids` is copied from those `input_ids`; it is not read
back from SGLang.
- Accumulates per-turn records into the `Sample` you receive at the end of
the session, with `tokens` and `rollout_log_probs` already populated.
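
To make this concrete, here is a minimal sketch of the request rewrite the middleware performs before proxying to SGLang. It is illustrative only: the real logic lives in `miles/rollout/session/sessions.py`, and `prior_input_ids` stands in for the session's accumulated token prefix.

```python
# Illustrative sketch of the session middleware's request rewrite.
# `prior_input_ids` is a stand-in for the session's accumulated prefix.
from typing import Any


def prepare_tito_request(request_body: dict[str, Any], prior_input_ids: list[int]) -> dict[str, Any]:
    # Hardcoded (not setdefault): agent-passed values must not win.
    request_body["logprobs"] = True          # populates meta_info.output_token_logprobs
    request_body["return_meta_info"] = True  # wraps the logprobs in choice.meta_info
    request_body["no_stop_trim"] = False     # stop-token text trimmed from content
    # Miles owns prompt tokenization: inject the tracked prefix directly.
    request_body["input_ids"] = prior_input_ids
    return request_body
```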

2 changes: 1 addition & 1 deletion miles/rollout/generate_utils/openai_endpoint_utils.py
@@ -168,7 +168,7 @@ def _compute_sample_from_openai_record(
if "prompt_token_ids" in choice:
prompt_token_ids = choice["prompt_token_ids"]
else:
raise ValueError("prompt_token_ids not found in response choice — ensure return_prompt_token_ids=True is set")
raise ValueError("prompt_token_ids not found in response choice — the session server should populate it")

output_token_ids = [item[1] for item in choice["meta_info"]["output_token_logprobs"]]
output_log_probs = [item[0] for item in choice["meta_info"]["output_token_logprobs"]]
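
For orientation, the fields this helper reads look roughly like the sketch below. The `(logprob, token_id, token_text)` tuple layout is an assumption consistent with the indexing above, and all values are made up.

```python
# Assumed response fragment; values are invented, and the
# (logprob, token_id, token_text) layout matches the indexing above.
choice = {
    "prompt_token_ids": [151644, 872, 151645],  # populated by the session server
    "meta_info": {
        "output_token_logprobs": [
            (-0.01, 19, " 4"),
            (-0.0003, 151643, "<|endoftext|>"),
        ],
        "completion_tokens": 2,
    },
}
output_token_ids = [t[1] for t in choice["meta_info"]["output_token_logprobs"]]
output_log_probs = [t[0] for t in choice["meta_info"]["output_token_logprobs"]]
assert output_token_ids == [19, 151643]
```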
24 changes: 2 additions & 22 deletions miles/rollout/session/linear_trajectory.py
@@ -6,11 +6,7 @@

from miles.rollout.session.session_errors import MessageValidationError, SessionNotFoundError, TokenizationError
from miles.rollout.session.session_types import SessionRecord
from miles.utils.chat_template_utils import (
apply_chat_template,
assert_messages_append_only_with_allowed_role,
message_matches,
)
from miles.utils.chat_template_utils import assert_messages_append_only_with_allowed_role, message_matches
from miles.utils.chat_template_utils.tito_tokenizer import TITOTokenizer

logger = logging.getLogger(__name__)
@@ -21,19 +17,6 @@
MAX_ASSISTANT_ROLLBACK_STEPS = 1


def _assert_no_user_after_assistant(messages: list[dict[str, Any]]) -> None:
"""Assert no user message appears after the first assistant message."""
seen_assistant = False
for i, msg in enumerate(messages):
role = msg.get("role")
if role == "assistant":
seen_assistant = True
elif role == "user" and seen_assistant:
raise MessageValidationError(
f"invalid message structure: user message at index {i} " f"appears after the first assistant message"
)


@dataclass
class LinearTrajectory:
"""State for a linear trajectory.
@@ -281,13 +264,10 @@ def compute_session_mismatch(self, session: LinearTrajectory) -> list[dict] | None:
return None
try:
tools = session.records[-1].request.get("tools") if session.records else None
expected_ids = apply_chat_template(
expected_ids = self.tito_tokenizer.tokenize_prompt(
session.messages,
tokenizer=self.tokenizer,
tools=tools,
add_generation_prompt=False,
tokenize=True,
**self.tito_tokenizer.chat_template_kwargs,
)
mismatches = self.comparator.compare_sequences(expected_ids, session.token_ids)
return [m.to_dict() for m in mismatches]
38 changes: 25 additions & 13 deletions miles/rollout/session/sessions.py
@@ -137,20 +137,18 @@ async def chat_completions(request: Request, session_id: str):
body = await request.body()
request_body = json.loads(body) if body else {}

# TITO token tracking requires three SGLang flags working together:
# logprobs=True → populates meta_info.output_token_logprobs
# return_prompt_token_ids → adds choice.prompt_token_ids
# return_meta_info → wraps the above in choice.meta_info
# All three are hardcoded (not setdefault) to prevent agent-side
# TITO token tracking requires Miles-owned input_ids plus SGLang
# output-token metadata:
# logprobs=True → populates meta_info.output_token_logprobs
# return_meta_info → wraps the above in choice.meta_info
# Both flags are hardcoded (not setdefault) to prevent agent-side
# overrides from breaking the token accumulation invariants.
request_body["logprobs"] = True
request_body["return_prompt_token_ids"] = True
request_body["return_meta_info"] = True
if getattr(args, "use_rollout_routing_replay", False):
request_body["return_routed_experts"] = True
# Must be False so stop tokens are trimmed from output: otherwise the
# agent sees stop-token text in content, and the accumulated checkpoint
# would duplicate structural delimiters that the chat template also emits.
# Must be False so stop-token text is trimmed from assistant
# message content; token IDs are still taken from logprobs below.
request_body["no_stop_trim"] = False

request_messages = request_body.get("messages", [])
@@ -161,10 +159,15 @@
)
if pretokenized is not None:
request_body["input_ids"] = pretokenized["input_ids"]
logger.debug(
"Using pretokenized input_ids: %d tokens",
len(pretokenized["input_ids"]),
else:
request_body["input_ids"] = registry.tito_tokenizer.tokenize_prompt(
request_messages,
tools=request_body.get("tools"),
)
logger.debug(
"Using TITO input_ids: %d tokens",
len(request_body["input_ids"]),
)

body = json.dumps(request_body).encode()
expected_num_assistant = session.num_assistant
@@ -195,7 +198,16 @@
"an empty content rather than None. Please check your modified SGLang version."
)

prompt_token_ids = choice.get("prompt_token_ids")
prompt_token_ids = request_body["input_ids"]
# The rollout sample builder still consumes this response field, but
# TITO prompt tokenization is owned by Miles rather than SGLang.
choice["prompt_token_ids"] = prompt_token_ids
result["response_body"] = json.dumps(response).encode()
result["headers"] = {
k: v
for k, v in result["headers"].items()
if k.lower() not in ("content-length", "transfer-encoding", "content-encoding")
}
output_token_logprobs = meta_info["output_token_logprobs"]
completion_tokens = meta_info["completion_tokens"]

86 changes: 86 additions & 0 deletions miles/utils/chat_template_utils/template.py
@@ -206,6 +206,81 @@ def assert_messages_append_only_with_allowed_role(
)


def _is_deepseek_v4_tokenizer(tokenizer: Any) -> bool:
# TODO: Is there a better way to check?
names = [getattr(tokenizer, "name_or_path", None)]
init_kwargs = getattr(tokenizer, "init_kwargs", None)
if isinstance(init_kwargs, dict):
names.append(init_kwargs.get("name_or_path"))

for name in names:
if name and "deepseek-v4" in str(name).lower().replace("_", "-"):
return True
return False
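
As a quick illustration of the normalization this check performs (underscores folded to hyphens, case-insensitive), using stand-in objects that only carry the attributes the function reads:

```python
# Stand-in tokenizer objects; only the attributes the check reads are set.
from types import SimpleNamespace

tok = SimpleNamespace(name_or_path="org/DeepSeek_V4-Base", init_kwargs={})
assert _is_deepseek_v4_tokenizer(tok)  # "deepseek_v4" normalizes to "deepseek-v4"

other = SimpleNamespace(name_or_path="org/qwen3", init_kwargs={})
assert not _is_deepseek_v4_tokenizer(other)
```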


def _canonicalize_deepseek_v4_tools(tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Match SGLang's DSv4 OpenAI Tool canonicalization before encoding."""
wrapped = [
tool if isinstance(tool, dict) and "function" in tool else {"type": "function", "function": tool}
for tool in tools
]
validated = TypeAdapter(list[Tool]).validate_python(copy.deepcopy(wrapped))
return [tool.model_dump() for tool in validated]


def _render_deepseek_v4_messages(
messages: list[dict[str, Any]],
*,
add_generation_prompt: bool,
tools: list[dict[str, Any]] | None = None,
**kwargs,
) -> str:
from sglang.srt.entrypoints.openai import encoding_dsv4
from sglang.srt.parser.jinja_template_utils import process_content_for_template_format

remaining_kwargs = dict(kwargs)
thinking = remaining_kwargs.pop("thinking", remaining_kwargs.pop("enable_thinking", False))
reasoning_effort = remaining_kwargs.pop("reasoning_effort", None)
if remaining_kwargs:
raise ValueError(f"Unsupported DeepSeek V4 chat-template kwargs: {sorted(remaining_kwargs)}")

rendered_messages = copy.deepcopy(messages)
for i, msg in enumerate(rendered_messages):
if isinstance(msg.get("content"), list):
rendered_messages[i] = process_content_for_template_format(msg, "string", [], [], [], [])
msg = rendered_messages[i]
if msg.get("role") == "assistant" and msg.get("tool_calls"):
if msg.get("content") is None:
msg["content"] = ""
for tool_call in msg["tool_calls"]:
function = tool_call.get("function", {})
arguments = function.get("arguments")
if isinstance(arguments, dict):
function["arguments"] = json.dumps(arguments, ensure_ascii=False)

if not rendered_messages or rendered_messages[0].get("role") != "system":
rendered_messages.insert(0, {"role": "system", "content": ""})
if tools:
rendered_messages[0]["tools"] = _canonicalize_deepseek_v4_tools(tools)

prompt = encoding_dsv4.encode_messages(
rendered_messages,
thinking_mode="thinking" if thinking else "chat",
reasoning_effort=reasoning_effort if reasoning_effort in ("max", "high") else None,
)
if add_generation_prompt:
return prompt

for suffix in (
encoding_dsv4.ASSISTANT_SP_TOKEN + encoding_dsv4.thinking_start_token,
encoding_dsv4.ASSISTANT_SP_TOKEN + encoding_dsv4.thinking_end_token,
):
if prompt.endswith(suffix):
return prompt[: -len(suffix)]
return prompt
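
A usage sketch, assuming an SGLang build that ships `sglang.srt.entrypoints.openai.encoding_dsv4`:

```python
# Requires an SGLang build exposing encoding_dsv4; messages are sample data.
messages = [
    {"role": "system", "content": "You are helpful."},
    {"role": "user", "content": "Hi"},
]
prompt = _render_deepseek_v4_messages(messages, add_generation_prompt=True)
# With add_generation_prompt=False, the trailing assistant/thinking opener is
# stripped so a completed assistant turn can be appended verbatim.
```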


def apply_chat_template(
messages: list[dict],
*,
@@ -221,6 +296,17 @@ ensuring the result is ``str`` (tokenize=False) or ``list[int]``
ensuring the result is ``str`` (tokenize=False) or ``list[int]``
(tokenize=True), not a ``BatchEncoding`` or ``dict``.
"""
if _is_deepseek_v4_tokenizer(tokenizer):
rendered = _render_deepseek_v4_messages(
messages,
add_generation_prompt=add_generation_prompt,
tools=tools,
**kwargs,
)
if tokenize:
return tokenizer.encode(rendered, add_special_tokens=False)
return rendered

messages = _normalize_tool_arguments(messages)
tool_defs = extract_tool_dicts(tools)
render_kwargs = dict(add_generation_prompt=add_generation_prompt, **kwargs)
62 changes: 62 additions & 0 deletions miles/utils/chat_template_utils/tito_tokenizer.py
@@ -139,6 +139,21 @@ def _render_messages(
def _encode_text(self, text: str) -> list[int]:
return self.tokenizer.encode(text, add_special_tokens=False)

def tokenize_prompt(
self,
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
*,
add_generation_prompt: bool = True,
) -> list[int]:
return self._encode_text(
self._render_messages(
messages,
add_generation_prompt=add_generation_prompt,
tools=tools,
)
)

def _split_appended_segments(self, appended_messages: list[dict[str, Any]]) -> list[list[dict[str, Any]]]:
segments: list[list[dict[str, Any]]] = []
i = 0
@@ -457,6 +472,51 @@ def merge_tokens(
return prefix + incremental


# ---------------------------------------------------------------------------
# DeepSeek V4 implementation
# ---------------------------------------------------------------------------


class DeepSeekV4TITOTokenizer(TITOTokenizer):
"""DeepSeek V4 — official encoder via SGLang."""

reasoning_parser = "deepseek-v4"
tool_call_parser = "deepseekv4"

SUPPORTED_TEMPLATES = (
FixedTemplateRow(
allowed_roles=frozenset({"tool"}),
template=None,
),
)

_default_assistant_start_str: str = "<|Assistant|>"

def __init__(
self,
tokenizer: Any,
chat_template_kwargs: dict[str, Any] | None = None,
assistant_start_str: str | None = None,
allowed_append_roles: list[str] | None = None,
):
super().__init__(
tokenizer,
chat_template_kwargs,
assistant_start_str or self._default_assistant_start_str,
allowed_append_roles=allowed_append_roles,
)
self._user_id: int = tokenizer.convert_tokens_to_ids("<|User|>")
self._assistant_id: int = tokenizer.convert_tokens_to_ids("<|Assistant|>")

def create_comparator(self) -> TokenSeqComparator:
return TokenSeqComparator(
self.tokenizer,
assistant_start_str=self._assistant_start_str,
special_token_ids={self._user_id, self._assistant_id},
trim_trailing_ids=self.trailing_token_ids or None,
)


# ---------------------------------------------------------------------------
# Enum + Registry + Factory
# ---------------------------------------------------------------------------
@@ -468,6 +528,7 @@ class TITOTokenizerType(str, Enum):
QWEN35 = "qwen35"
QWENNEXT = "qwennext"
GLM47 = "glm47"
DEEPSEEKV4 = "deepseekv4"


_TOKENIZER_REGISTRY: dict[TITOTokenizerType, type[TITOTokenizer]] = {
@@ -476,6 +537,7 @@ class TITOTokenizerType(str, Enum):
TITOTokenizerType.QWEN35: Qwen35TITOTokenizer,
TITOTokenizerType.QWENNEXT: QwenNextTITOTokenizer,
TITOTokenizerType.GLM47: GLM47TITOTokenizer,
TITOTokenizerType.DEEPSEEKV4: DeepSeekV4TITOTokenizer,
}
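
An end-to-end sketch of picking the new tokenizer from the registry and producing TITO prompt ids; the model id is a placeholder, and the factory that normally constructs these lives elsewhere in this module:

```python
# Sketch only: the model id is a placeholder.
from transformers import AutoTokenizer

hf_tok = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V4")  # placeholder id
tito = _TOKENIZER_REGISTRY[TITOTokenizerType.DEEPSEEKV4](hf_tok)
ids = tito.tokenize_prompt(
    [{"role": "user", "content": "Answer with one word: 2+2?"}],
)
# `ids` is what the session middleware injects as Miles-owned input_ids.
```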


21 changes: 19 additions & 2 deletions miles/utils/dumper_utils.py
@@ -12,7 +12,14 @@

import torch
import torch.distributed as dist
from sglang.srt.debug_utils.dumper import DumperConfig, _get_rank, dumper
from sglang.srt.debug_utils.dumper import _get_rank, dumper

# Note: this is a temporary compatibility workaround for the SGLang v4
# branch; remove it once the change is upstreamed.
try:
from sglang.srt.debug_utils.dumper import DumperConfig
except ImportError:
DumperConfig = None

logger = logging.getLogger(__name__)

@@ -100,6 +107,8 @@ def _configure(args: Namespace, phase: DumperPhase) -> bool:
overrides = _get_phase_override_configs(args, phase)
if not overrides.get("enable"):
return False
if DumperConfig is None:
raise ImportError("The active SGLang dumper module does not expose DumperConfig.")

merged = {
"dir": str(_get_dir(args)),
@@ -139,7 +148,15 @@

def _get_phase_override_configs(args: Namespace, phase: DumperPhase) -> dict[str, Any]:
raw = getattr(args, f"dumper_{phase.value}")
return {"enable": args.dumper_enable, **DumperConfig._kv_pairs_to_dict(raw)}
return {"enable": args.dumper_enable, **_parse_dumper_kv_pairs(raw)}


def _parse_dumper_kv_pairs(raw: list[str] | None) -> dict[str, Any]:
if not raw:
return {}
if DumperConfig is None:
raise ImportError("The active SGLang dumper module does not expose DumperConfig.")
return DumperConfig._kv_pairs_to_dict(raw)
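
A small sketch of the failure mode this guard produces on an SGLang branch without `DumperConfig`, simulated here by nulling the symbol (assumes `miles` and SGLang are importable):

```python
# Simulates an SGLang branch whose dumper module lacks DumperConfig.
import miles.utils.dumper_utils as du

du.DumperConfig = None  # pretend the guarded import failed
try:
    du._parse_dumper_kv_pairs(["dir=/tmp/dumps"])
except ImportError as exc:
    print(exc)  # "The active SGLang dumper module does not expose DumperConfig."
```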


def _is_phase_enabled(args: Namespace, phase: DumperPhase) -> bool: