Skip to content

Commit d36ebc4

Browse files
committed
fix for glm5
fix
1 parent 8294e64 commit d36ebc4

5 files changed

Lines changed: 95 additions & 8 deletions

File tree

src/srtctl/benchmarks/sa_bench.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,5 +97,7 @@ def build_command(
9797
str(prefill_gpus),
9898
str(decode_gpus),
9999
str(b.random_range_ratio) if b.random_range_ratio is not None else "0.8",
100+
b.custom_tokenizer or "",
101+
str(b.use_chat_template).lower(),
100102
]
101103
return cmd

src/srtctl/benchmarks/scripts/sa-bench/backend_request_func.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -511,10 +511,52 @@ def get_model(pretrained_model_name_or_path: str) -> str:
511511
return pretrained_model_name_or_path
512512

513513

514+
def _load_glm_moe_dsa_tokenizer(pretrained_model_name_or_path: str) -> "PreTrainedTokenizerFast":
    """Load the GLM-Moe-Dsa / GLM-5 tokenizer directly from ``tokenizer.json``.

    Works around incompatibilities when the checkpoint was saved with
    transformers 5.x (TokenizersBackend / list-style extra_special_tokens)
    by building a ``PreTrainedTokenizerFast`` around the raw Rust tokenizer
    and forwarding only a whitelist of keys from ``tokenizer_config.json``.

    Args:
        pretrained_model_name_or_path: Local directory that contains
            ``tokenizer.json`` (and optionally ``tokenizer_config.json``).

    Returns:
        A ``PreTrainedTokenizerFast`` wrapping the tokenizer read from
        ``tokenizer.json``.

    Raises:
        FileNotFoundError: If ``tokenizer.json`` is not present in the
            given directory.
    """
    import json
    from pathlib import Path

    _SAFE_CONFIG_KEYS = (
        "pad_token", "pad_token_id", "eos_token", "eos_token_id",
        "bos_token", "bos_token_id", "unk_token", "unk_token_id",
        "model_max_length", "padding_side", "truncation_side",
    )

    def _token_text(token):
        # A token may be serialized as an AddedToken-style dict; the plain
        # string lives under "content". Anything else is passed through.
        if isinstance(token, dict) and "content" in token:
            return token["content"]
        return token

    path = Path(pretrained_model_name_or_path)
    tokenizer_json = path / "tokenizer.json"
    # Validate the path before the heavy third-party imports so a wrong
    # directory fails fast with a clear message.
    if not tokenizer_json.exists():
        raise FileNotFoundError(
            f"Expected tokenizer.json at {tokenizer_json}. "
            "GlmMoeDsaTokenizer loads from tokenizer.json only."
        )

    from tokenizers import Tokenizer as RustTokenizer
    from transformers import PreTrainedTokenizerFast

    rust_tok = RustTokenizer.from_file(str(tokenizer_json))
    init_kwargs = {}
    config_path = path / "tokenizer_config.json"
    if config_path.exists():
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        for key in _SAFE_CONFIG_KEYS:
            if key in config:
                value = config[key]
                # Only the "*_token" entries can be dict-serialized tokens;
                # ids and padding/truncation settings are forwarded as-is.
                init_kwargs[key] = _token_text(value) if key.endswith("_token") else value
        extra = config.get("extra_special_tokens")
        # NOTE(review): some checkpoints appear to store this as a
        # name -> token mapping rather than a list; additional_special_tokens
        # must be a list, so keep only the token values in that case.
        if isinstance(extra, dict):
            extra = list(extra.values())
        if isinstance(extra, (list, tuple)) and extra:
            init_kwargs["additional_special_tokens"] = [_token_text(tok) for tok in extra]

    return PreTrainedTokenizerFast(tokenizer_object=rust_tok, **init_kwargs)
553+
554+
514555
def get_tokenizer(
515556
pretrained_model_name_or_path: str,
516557
tokenizer_mode: str = "auto",
517558
trust_remote_code: bool = False,
559+
custom_tokenizer: str | None = None,
518560
**kwargs,
519561
) -> PreTrainedTokenizer | PreTrainedTokenizerFast:
520562
if pretrained_model_name_or_path is not None and not os.path.exists(pretrained_model_name_or_path):
@@ -533,12 +575,28 @@ def get_tokenizer(
533575
"to use mistral tokenizer mode."
534576
) from e
535577
return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path))
536-
else:
537-
return AutoTokenizer.from_pretrained(
538-
pretrained_model_name_or_path,
539-
trust_remote_code=trust_remote_code,
540-
**kwargs,
541-
)
578+
if custom_tokenizer:
579+
if custom_tokenizer == "glm_moe_dsa":
580+
return _load_glm_moe_dsa_tokenizer(pretrained_model_name_or_path)
581+
from importlib import import_module
582+
try:
583+
module_path, class_name = custom_tokenizer.rsplit('.', 1)
584+
module = import_module(module_path)
585+
tokenizer_class = getattr(module, class_name)
586+
return tokenizer_class.from_pretrained(
587+
pretrained_model_name_or_path,
588+
trust_remote_code=trust_remote_code,
589+
**kwargs,
590+
)
591+
except (ValueError, ImportError, AttributeError) as e:
592+
raise ValueError(
593+
f"Failed to load custom_tokenizer '{custom_tokenizer}'. "
594+
"Expected 'glm_moe_dsa' or 'module.path.ClassName'.") from e
595+
return AutoTokenizer.from_pretrained(
596+
pretrained_model_name_or_path,
597+
trust_remote_code=trust_remote_code,
598+
**kwargs,
599+
)
542600

543601

544602
ASYNC_REQUEST_FUNCS = {

src/srtctl/benchmarks/scripts/sa-bench/bench.sh

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,20 @@ TOTAL_GPUS=${9:-0}
6060
PREFILL_GPUS=${10:-0}
6161
DECODE_GPUS=${11:-0}
6262
RANDOM_RANGE_RATIO=${12:-0.8}
63+
CUSTOM_TOKENIZER=${13:-}
64+
USE_CHAT_TEMPLATE=${14:-true}
65+
66+
# Build optional custom tokenizer args
67+
CUSTOM_TOKENIZER_ARGS=()
68+
if [ -n "$CUSTOM_TOKENIZER" ]; then
69+
CUSTOM_TOKENIZER_ARGS=(--custom-tokenizer "$CUSTOM_TOKENIZER")
70+
fi
71+
72+
# Build optional chat template args
73+
CHAT_TEMPLATE_ARGS=()
74+
if [ "$USE_CHAT_TEMPLATE" = "true" ]; then
75+
CHAT_TEMPLATE_ARGS=(--use-chat-template)
76+
fi
6377

6478
# Parse endpoint into host:port
6579
HOST=$(echo "$ENDPOINT" | sed 's|http://||' | cut -d: -f1)
@@ -119,7 +133,8 @@ for concurrency in "${CONCURRENCY_LIST[@]}"; do
119133
--request-rate 250 \
120134
--percentile-metrics ttft,tpot,itl,e2el \
121135
--max-concurrency "$concurrency" \
122-
--trust-remote-code
136+
--trust-remote-code \
137+
"${CUSTOM_TOKENIZER_ARGS[@]}"
123138

124139
num_prompts=$((concurrency * 10))
125140

@@ -149,7 +164,8 @@ for concurrency in "${CONCURRENCY_LIST[@]}"; do
149164
--percentile-metrics ttft,tpot,itl,e2el \
150165
--max-concurrency "$concurrency" \
151166
--trust-remote-code \
152-
--use-chat-template \
167+
"${CHAT_TEMPLATE_ARGS[@]}" \
168+
"${CUSTOM_TOKENIZER_ARGS[@]}" \
153169
--save-result --result-dir "$result_dir" --result-filename "$result_filename"
154170
set +x
155171

src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -837,6 +837,7 @@ def main(args: argparse.Namespace):
837837
tokenizer_id,
838838
tokenizer_mode=tokenizer_mode,
839839
trust_remote_code=args.trust_remote_code,
840+
custom_tokenizer=args.custom_tokenizer,
840841
)
841842

842843
if args.dataset is not None:
@@ -1279,6 +1280,14 @@ def main(args: argparse.Namespace):
12791280
'"custom" will use --tokenizer to select the preregistered tokenizer.',
12801281
)
12811282

1283+
parser.add_argument(
1284+
"--custom-tokenizer",
1285+
type=str,
1286+
default=None,
1287+
help="Custom tokenizer to use (e.g., 'glm_moe_dsa' or 'module.path.ClassName'). "
1288+
"When set, overrides the default tokenizer loading.",
1289+
)
1290+
12821291
parser.add_argument(
12831292
"--served-model-name",
12841293
type=str,

src/srtctl/core/schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,8 @@ class BenchmarkConfig:
539539
ttft_threshold_ms: int | None = None # Goodput TTFT threshold in ms (default: 2000)
540540
itl_threshold_ms: int | None = None # Goodput ITL threshold in ms (default: 25)
541541
random_range_ratio: float | None = None # Random input/output length range ratio (default: 0.8)
542+
custom_tokenizer: str | None = None # Custom tokenizer: "glm_moe_dsa" or a dotted class path (e.g., "module.path.ClassName")
543+
use_chat_template: bool = True # Pass --use-chat-template to benchmark (default: true)
542544

543545
def get_concurrency_list(self) -> list[int]:
544546
if self.concurrencies is None:

0 commit comments

Comments
 (0)