|
| 1 | +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +# SPDX-License-Identifier: Apache-2.0 |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +"""Offline ISL (Input Sequence Length) computation for multi-turn datasets. |
| 17 | +
|
| 18 | +Run from the repo root to print the ISL distribution for a dataset:: |
| 19 | +
|
| 20 | + python scripts/agentic_inference_isl_precompute.py \\ |
| 21 | + --dataset path/to/dataset.jsonl \\ |
| 22 | + --tokenizer <model-name-or-path> |
| 23 | +""" |
| 24 | + |
| 25 | +from __future__ import annotations |
| 26 | + |
| 27 | +import argparse |
| 28 | +import logging |
| 29 | +import os |
| 30 | +import threading |
| 31 | +from concurrent.futures import ThreadPoolExecutor, as_completed |
| 32 | + |
| 33 | +import pandas as pd |
| 34 | +from inference_endpoint.async_utils.services.metrics_aggregator.token_metrics import ( |
| 35 | + _normalize_tool_calls_for_template, |
| 36 | +) |
| 37 | +from inference_endpoint.dataset_manager.multi_turn_dataset import MultiTurnDataset |
| 38 | +from tqdm import tqdm |
| 39 | +from transformers import AutoTokenizer |
| 40 | + |
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
| 42 | + |
| 43 | + |
| 44 | +def _precompute_isl(dataloader: MultiTurnDataset, tokenizer_name: str) -> None: |
| 45 | + samples_with_messages = [s for s in (dataloader.data or []) if s.get("messages")] |
| 46 | + if not samples_with_messages: |
| 47 | + return |
| 48 | + |
| 49 | + try: |
| 50 | + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) |
| 51 | + except Exception: |
| 52 | + logger.exception("Failed to load tokenizer %s", tokenizer_name) |
| 53 | + return |
| 54 | + |
| 55 | + first_failure_logged = False |
| 56 | + first_failure_lock = threading.Lock() |
| 57 | + |
| 58 | + def _tokenize_sample(sample: dict) -> list[int] | None: |
| 59 | + try: |
| 60 | + normalized_messages = [] |
| 61 | + for msg in sample["messages"]: |
| 62 | + if msg.get("tool_calls"): |
| 63 | + msg = { |
| 64 | + **msg, |
| 65 | + "tool_calls": _normalize_tool_calls_for_template( |
| 66 | + msg["tool_calls"] |
| 67 | + ), |
| 68 | + } |
| 69 | + normalized_messages.append(msg) |
| 70 | + tools = sample.get("tools") |
| 71 | + raw = tokenizer.apply_chat_template( |
| 72 | + normalized_messages, |
| 73 | + tools=tools if tools else None, |
| 74 | + tokenize=True, |
| 75 | + add_generation_prompt=True, |
| 76 | + ) |
| 77 | + # Some tokenizers (e.g. Qwen3 fast tokenizer) return BatchEncoding |
| 78 | + # instead of a plain list; extract .input_ids in that case. |
| 79 | + token_ids: list[int] = raw.input_ids if hasattr(raw, "input_ids") else raw |
| 80 | + return token_ids |
| 81 | + except Exception: |
| 82 | + nonlocal first_failure_logged |
| 83 | + with first_failure_lock: |
| 84 | + if not first_failure_logged: |
| 85 | + logger.exception("apply_chat_template failed (first failure shown)") |
| 86 | + first_failure_logged = True |
| 87 | + return None |
| 88 | + |
| 89 | + n_workers = min(os.cpu_count() or 32, 32) |
| 90 | + skipped = 0 |
| 91 | + with ThreadPoolExecutor( |
| 92 | + max_workers=n_workers, thread_name_prefix="ISLPrecompute" |
| 93 | + ) as pool: |
| 94 | + futures = { |
| 95 | + pool.submit(_tokenize_sample, sample): sample |
| 96 | + for sample in samples_with_messages |
| 97 | + } |
| 98 | + for future in tqdm( |
| 99 | + as_completed(futures), |
| 100 | + total=len(futures), |
| 101 | + desc="Pre-computing ISL", |
| 102 | + unit="turn", |
| 103 | + ): |
| 104 | + sample = futures[future] |
| 105 | + token_ids = future.result() |
| 106 | + if token_ids is not None: |
| 107 | + sample["input_tokens"] = token_ids |
| 108 | + else: |
| 109 | + skipped += 1 |
| 110 | + |
| 111 | + if skipped: |
| 112 | + logger.warning("%d turn(s) skipped (apply_chat_template failed)", skipped) |
| 113 | + if skipped == len(samples_with_messages): |
| 114 | + logger.warning( |
| 115 | + "All %d turn(s) failed apply_chat_template. " |
| 116 | + "Check tokenizer/template compatibility.", |
| 117 | + len(samples_with_messages), |
| 118 | + ) |
| 119 | + |
| 120 | + |
| 121 | +def _isl_distribution(dataloader: MultiTurnDataset) -> dict[str, float]: |
| 122 | + values = sorted( |
| 123 | + len(s["input_tokens"]) |
| 124 | + for s in (dataloader.data or []) |
| 125 | + if s.get("input_tokens") is not None |
| 126 | + ) |
| 127 | + if not values: |
| 128 | + raise ValueError( |
| 129 | + "No input_tokens found — tokenization may have failed entirely." |
| 130 | + ) |
| 131 | + n = len(values) |
| 132 | + |
| 133 | + def percentile(p: float) -> float: |
| 134 | + idx = (p / 100) * (n - 1) |
| 135 | + lo, frac = int(idx), idx % 1 |
| 136 | + return values[lo] + frac * (values[lo + 1] - values[lo] if lo + 1 < n else 0) |
| 137 | + |
| 138 | + return { |
| 139 | + "min": values[0], |
| 140 | + "max": values[-1], |
| 141 | + "mean": sum(values) / n, |
| 142 | + "p50": percentile(50), |
| 143 | + "p99": percentile(99), |
| 144 | + } |
| 145 | + |
| 146 | + |
def main() -> None:
    """Command-line entry point: tokenize a dataset and print its ISL stats.

    Reads ``--dataset`` (a JSON-lines file) and ``--tokenizer`` from the
    command line, loads the dataset, fills in ``"input_tokens"`` for each turn
    and prints the min / mean / p50 / p99 / max token counts to stdout.
    """
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")

    cli = argparse.ArgumentParser(
        description="Compute ISL distribution for a multi-turn dataset."
    )
    cli.add_argument("--dataset", required=True, help="Path to JSONL dataset file.")
    cli.add_argument(
        "--tokenizer", required=True, help="HuggingFace repo ID or local path."
    )
    options = cli.parse_args()

    frame = pd.read_json(options.dataset, lines=True)
    dataset = MultiTurnDataset(frame)
    dataset.load()
    _precompute_isl(dataset, options.tokenizer)

    stats = _isl_distribution(dataset)
    # Count only the turns that actually received token IDs.
    tokenized = len(
        [turn for turn in (dataset.data or []) if turn.get("input_tokens") is not None]
    )
    report = [
        f"ISL distribution ({tokenized} turns)",
        f" min : {stats['min']:.0f}",
        f" mean : {stats['mean']:.1f}",
        f" p50 : {stats['p50']:.0f}",
        f" p99 : {stats['p99']:.0f}",
        f" max : {stats['max']:.0f}",
    ]
    print("\n".join(report))
| 171 | + |
| 172 | + |
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
0 commit comments