Skip to content

Commit 7358b79

Browse files
authored
Merge branch 'main' into sysinfochanges
2 parents 1a39d48 + 14ca054 commit 7358b79

11 files changed

Lines changed: 489 additions & 644 deletions

File tree

Lines changed: 174 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,174 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Offline ISL (Input Sequence Length) computation for multi-turn datasets.
17+
18+
Run from the repo root to print the ISL distribution for a dataset::
19+
20+
python scripts/agentic_inference_isl_precompute.py \\
21+
--dataset path/to/dataset.jsonl \\
22+
--tokenizer <model-name-or-path>
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import argparse
28+
import logging
29+
import os
30+
import threading
31+
from concurrent.futures import ThreadPoolExecutor, as_completed
32+
33+
import pandas as pd
34+
from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import (
35+
_normalize_tool_calls_for_template,
36+
)
37+
from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset
38+
from tqdm import tqdm
39+
from transformers import AutoTokenizer
40+
41+
logger = logging.getLogger(__name__)
42+
43+
44+
def _precompute_isl(dataloader: MultiTurnDataset, tokenizer_name: str) -> None:
    """Tokenize every turn that carries ``messages`` and store its token IDs.

    For each sample in ``dataloader.data`` with a non-empty ``"messages"``
    entry, the tokenizer's chat template is applied (with a generation prompt)
    and the resulting token IDs are written to ``sample["input_tokens"]``.
    Samples are tokenized on a thread pool. Per-sample failures are counted
    and reported through the module logger instead of being raised.

    Args:
        dataloader: Object whose ``data`` attribute is an iterable of sample
            dicts (or ``None``). Samples are mutated in place.
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        None. When no sample has messages, or when the tokenizer cannot be
        loaded (the exception is logged), the samples are left untouched.
    """
    samples_with_messages = [s for s in (dataloader.data or []) if s.get("messages")]
    if not samples_with_messages:
        return

    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    except Exception:
        logger.exception("Failed to load tokenizer %s", tokenizer_name)
        return

    # Only the first failure gets a full traceback; the flag is shared by all
    # worker threads, so it is read and written under a lock.
    first_failure_logged = False
    first_failure_lock = threading.Lock()

    def _tokenize_sample(sample: dict) -> list[int] | None:
        """Return the token IDs for one sample, or None if templating fails."""
        nonlocal first_failure_logged
        try:
            normalized_messages = []
            for msg in sample["messages"]:
                if msg.get("tool_calls"):
                    # Build a new dict so the dataset's own message object is
                    # not modified by the normalization.
                    msg = {
                        **msg,
                        "tool_calls": _normalize_tool_calls_for_template(
                            msg["tool_calls"]
                        ),
                    }
                normalized_messages.append(msg)
            tools = sample.get("tools")
            raw = tokenizer.apply_chat_template(
                normalized_messages,
                tools=tools if tools else None,
                tokenize=True,
                add_generation_prompt=True,
            )
            # Some tokenizers (e.g. Qwen3 fast tokenizer) return BatchEncoding
            # instead of a plain list; extract .input_ids in that case.
            token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw
            return token_ids
        except Exception:
            with first_failure_lock:
                if not first_failure_logged:
                    logger.exception("apply_chat_template failed (first failure shown)")
                    first_failure_logged = True
            return None

    n_workers = min(os.cpu_count() or 32, 32)
    skipped = 0
    with ThreadPoolExecutor(
        max_workers=n_workers, thread_name_prefix="ISLPrecompute"
    ) as pool:
        # Map each future back to its sample so results can be stored in
        # completion order.
        futures = {
            pool.submit(_tokenize_sample, sample): sample
            for sample in samples_with_messages
        }
        for future in tqdm(
            as_completed(futures),
            total=len(futures),
            desc="Pre-computing ISL",
            unit="turn",
        ):
            sample = futures[future]
            token_ids = future.result()
            if token_ids is not None:
                sample["input_tokens"] = token_ids
            else:
                skipped += 1

    if skipped:
        logger.warning("%d turn(s) skipped (apply_chat_template failed)", skipped)
        if skipped == len(samples_with_messages):
            logger.warning(
                "All %d turn(s) failed apply_chat_template. "
                "Check tokenizer/template compatibility.",
                len(samples_with_messages),
            )
119+
120+
121+
def _isl_distribution(dataloader: MultiTurnDataset) -> dict[str, float]:
    """Summarize the input-sequence-length distribution of tokenized samples.

    The ISL of a sample is ``len(sample["input_tokens"])``. Samples whose
    ``"input_tokens"`` entry is missing or ``None`` are ignored; an empty
    token list counts as length 0.

    Args:
        dataloader: Object whose ``data`` attribute is an iterable of sample
            dicts (or ``None``).

    Returns:
        A dict with keys ``"min"``, ``"max"``, ``"mean"``, ``"p50"`` and
        ``"p99"``. ``min``/``max`` are the observed integer lengths; the
        percentiles are linearly interpolated between neighbouring ranks.

    Raises:
        ValueError: If no sample carries ``input_tokens``.
    """
    values = sorted(
        len(s["input_tokens"])
        for s in (dataloader.data or [])
        if s.get("input_tokens") is not None
    )
    if not values:
        raise ValueError(
            "No input_tokens found — tokenization may have failed entirely."
        )
    n = len(values)

    def percentile(p: float) -> float:
        """Return the p-th percentile of ``values`` by linear interpolation."""
        # Fractional rank in [0, n - 1]; the integer part selects the lower
        # neighbour and the fractional part is the interpolation weight.
        idx = (p / 100) * (n - 1)
        lo, frac = int(idx), idx % 1
        # When ``lo`` is already the last element there is no upper neighbour,
        # so the interpolation term is zero.
        return values[lo] + frac * (values[lo + 1] - values[lo] if lo + 1 < n else 0)

    return {
        "min": values[0],
        "max": values[-1],
        "mean": sum(values) / n,
        "p50": percentile(50),
        "p99": percentile(99),
    }
145+
146+
147+
def main() -> None:
    """Command-line entry point: tokenize a dataset and print its ISL stats.

    Parses ``--dataset`` (JSONL path) and ``--tokenizer`` (HuggingFace repo ID
    or local path), loads the dataset into a ``MultiTurnDataset``, tokenizes
    its turns with ``_precompute_isl`` and prints min/mean/p50/p99/max of the
    resulting input sequence lengths.

    Raises:
        ValueError: Propagated from ``_isl_distribution`` when no turn ended
            up with ``input_tokens``.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")

    parser = argparse.ArgumentParser(
        description="Compute ISL distribution for a multi-turn dataset."
    )
    parser.add_argument("--dataset", required=True, help="Path to JSONL dataset file.")
    parser.add_argument(
        "--tokenizer", required=True, help="HuggingFace repo ID or local path."
    )
    args = parser.parse_args()

    ds = MultiTurnDataset(pd.read_json(args.dataset, lines=True))
    ds.load()
    _precompute_isl(ds, args.tokenizer)

    stats = _isl_distribution(ds)
    # Count only the turns that actually received token IDs; turns that failed
    # tokenization are excluded from the statistics above as well.
    n = sum(1 for s in (ds.data or []) if s.get("input_tokens") is not None)
    print(f"ISL distribution ({n} turns)")
    print(f" min : {stats['min']:.0f}")
    print(f" mean : {stats['mean']:.1f}")
    print(f" p50 : {stats['p50']:.0f}")
    print(f" p99 : {stats['p99']:.0f}")
    print(f" max : {stats['max']:.0f}")


if __name__ == "__main__":
    main()

src/inference_endpoint/commands/benchmark/execute.py

Lines changed: 12 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -43,7 +43,6 @@
4343
import msgspec.json
4444
from huggingface_hub import model_info
4545
from tqdm import tqdm
46-
from transformers import AutoTokenizer
4746
from transformers.utils import logging as transformers_logging
4847

4948
from inference_endpoint.async_utils.event_publisher import EventPublisherService
@@ -58,9 +57,6 @@
5857
from inference_endpoint.async_utils.services.metrics_aggregator.subscriber import (
5958
MetricsSnapshotSubscriber,
6059
)
61-
from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import (
62-
_normalize_tool_calls_for_template,
63-
)
6460
from inference_endpoint.async_utils.transport.zmq.context import ManagedZMQContext
6561
from inference_endpoint.config.runtime_settings import RuntimeSettings
6662
from inference_endpoint.config.schema import (
@@ -314,75 +310,6 @@ def _load_datasets(
314310
return dataloader, accuracy_datasets, eval_configs
315311

316312

317-
def _precompute_isl_for_multi_turn(
    dataloader: MultiTurnDataset, tokenizer_name: str
) -> None:
    """Tokenize pre-built message lists and store the token IDs in each sample.

    Runs apply_chat_template once per client turn so the hot-path IslTrigger
    sync path (len(token_ids)) is used instead of on-the-fly text tokenization.
    Only affects dataset-history turns; live-history turns override 'messages'
    at runtime so the stored input_tokens are stale (acceptable approximation).

    Args:
        dataloader: Dataset whose ``data`` attribute is an iterable of sample
            dicts (or ``None``). Samples with a non-empty ``"messages"`` entry
            receive an ``"input_tokens"`` entry in place.
        tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.

    Returns:
        None. Tokenizer-load and per-sample template failures are logged, not
        raised.
    """
    try:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    except Exception:
        logger.exception(
            "ISL pre-computation: failed to load tokenizer %s; "
            "falling back to text-tokenization at runtime",
            tokenizer_name,
        )
        return

    skipped = 0
    total_with_messages = 0
    # Only the first failure gets a full traceback to keep the log readable.
    first_failure_logged = False
    for sample in dataloader.data or []:
        messages = sample.get("messages")
        if not messages:
            continue
        total_with_messages += 1
        try:
            normalized_messages = []
            for msg in messages:
                if msg.get("tool_calls"):
                    # Build a new dict so the dataset's own message object is
                    # not modified by the normalization.
                    msg = {
                        **msg,
                        "tool_calls": _normalize_tool_calls_for_template(
                            msg["tool_calls"]
                        ),
                    }
                normalized_messages.append(msg)
            tools = sample.get("tools")
            raw = tokenizer.apply_chat_template(
                normalized_messages,
                tools=tools if tools else None,
                tokenize=True,
                add_generation_prompt=True,
            )
            # Some tokenizers (e.g. Qwen3 fast tokenizer) return BatchEncoding
            # instead of a plain list; extract .input_ids in that case.
            token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw
            sample["input_tokens"] = token_ids
        except Exception:
            if not first_failure_logged:
                logger.exception(
                    "ISL pre-computation: apply_chat_template failed (first failure shown)"
                )
                first_failure_logged = True
            skipped += 1

    if skipped:
        logger.warning(
            "ISL pre-computation: %d turn(s) skipped (apply_chat_template failed)",
            skipped,
        )
    if total_with_messages > 0 and skipped == total_with_messages:
        logger.warning(
            "ISL precomputation: all %d turn(s) failed apply_chat_template; "
            "ISL metrics will use text-tokenization fallback. "
            "Check tokenizer/template compatibility.",
            total_with_messages,
        )
384-
385-
386313
def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkContext:
387314
"""Load tokenizer, dataset, create scheduler, setup report dir."""
388315
# CPU affinity
@@ -401,7 +328,18 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo
401328

402329
# Tokenizer check (light API call, no download)
403330
model_name = config.model_params.name
404-
tokenizer_name = model_name if _check_tokenizer_exists(model_name) else None
331+
tokenizer_override = config.model_params.tokenizer_name
332+
tokenizer_name: str | None
333+
if tokenizer_override:
334+
if not _check_tokenizer_exists(tokenizer_override):
335+
raise SetupError(
336+
f"Tokenizer override '{tokenizer_override}' could not be verified. "
337+
"Check that the HF repo ID or local path is correct, accessible, and contains tokenizer files. "
338+
"See logs above for details."
339+
)
340+
tokenizer_name = tokenizer_override
341+
else:
342+
tokenizer_name = model_name if _check_tokenizer_exists(model_name) else None
405343

406344
# Streaming
407345
logger.info(
@@ -412,10 +350,6 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo
412350
# Datasets
413351
dataloader, accuracy_datasets, eval_configs = _load_datasets(config, report_dir)
414352

415-
if isinstance(dataloader, MultiTurnDataset) and tokenizer_name is not None:
416-
logger.info("Pre-computing ISL token counts for multi-turn dataset…")
417-
_precompute_isl_for_multi_turn(dataloader, tokenizer_name)
418-
419353
# Setup runtime settings using factory method
420354
rt_settings = RuntimeSettings.from_config(config, dataloader.num_samples())
421355

src/inference_endpoint/config/schema.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -205,6 +205,13 @@ class ModelParams(BaseModel):
205205
StreamingMode,
206206
cyclopts.Parameter(alias="--streaming", help="Streaming mode: auto/on/off"),
207207
] = StreamingMode.AUTO
208+
tokenizer_name: Annotated[
209+
str | None,
210+
cyclopts.Parameter(
211+
alias="--tokenizer",
212+
help="HF repo ID or local path for the tokenizer. Overrides model name for client-side token metrics (ISL/OSL/TPOT).",
213+
),
214+
] = None
208215

209216

210217
class SubmissionReference(BaseModel):

src/inference_endpoint/config/templates/concurrency_template_full.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@ model_params:
1515
max_new_tokens: 1024 # Max output tokens
1616
osl_distribution: null # Output sequence length distribution
1717
streaming: 'on' # Streaming mode: auto/on/off | options: auto, on, off
18+
tokenizer_name: null # HF repo ID or local path for the tokenizer. Overrides model name for client-side token metrics (ISL/OSL/TPOT).
1819
datasets: # Dataset configs
1920
- name: perf
2021
type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy

src/inference_endpoint/config/templates/offline_template_full.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@ model_params:
1515
max_new_tokens: 1024 # Max output tokens
1616
osl_distribution: null # Output sequence length distribution
1717
streaming: 'off' # Streaming mode: auto/on/off | options: auto, on, off
18+
tokenizer_name: null # HF repo ID or local path for the tokenizer. Overrides model name for client-side token metrics (ISL/OSL/TPOT).
1819
datasets: # Dataset configs
1920
- name: perf
2021
type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy

src/inference_endpoint/config/templates/online_template_full.yaml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@ model_params:
1515
max_new_tokens: 1024 # Max output tokens
1616
osl_distribution: null # Output sequence length distribution
1717
streaming: 'on' # Streaming mode: auto/on/off | options: auto, on, off
18+
tokenizer_name: null # HF repo ID or local path for the tokenizer. Overrides model name for client-side token metrics (ISL/OSL/TPOT).
1819
datasets: # Dataset configs
1920
- name: perf
2021
type: performance # Dataset purpose: performance or accuracy | options: performance, accuracy

0 commit comments

Comments
 (0)