4 changes: 4 additions & 0 deletions src/srtctl/benchmarks/scripts/sa-bench/bench.sh
@@ -104,6 +104,10 @@ if [ "$DATASET_NAME" = "random" ]; then
--random-input-len "$ISL"
--random-output-len "$OSL"
--random-range-ratio "${RANDOM_RANGE_RATIO}"
    # Parallel random prompt generation. The default of 48 saturates a
    # 144-core GB300 host; override via env (e.g. RANDOM_NUM_WORKERS=8).
    # 0 = auto = min(cpu_count, 8); 1 = serial (no multiprocessing).
--random-num-workers "${RANDOM_NUM_WORKERS:-48}"
)
fi
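
A couple of hypothetical launches showing the override (the rest of the benchmark setup is elided):

RANDOM_NUM_WORKERS=1 ./bench.sh   # force serial prompt generation (no multiprocessing)
RANDOM_NUM_WORKERS=0 ./bench.sh   # auto: min(cpu_count, 8) on the Python side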

121 changes: 100 additions & 21 deletions src/srtctl/benchmarks/scripts/sa-bench/benchmark_serving.py
@@ -40,6 +40,7 @@
from collections.abc import AsyncGenerator, Collection
from dataclasses import dataclass
from datetime import datetime
from multiprocessing import Pool, cpu_count
from typing import Any

import numpy as np
@@ -356,6 +357,48 @@ def sample_hf_requests(
return sampled_requests


# Worker-side tokenizer set by `_init_random_worker` (one per Pool worker
# process). The serial fallback below also writes here so the worker
# functions can be called identically in both paths.
_worker_tokenizer: PreTrainedTokenizerBase | None = None


def _init_random_worker(tokenizer_id, tokenizer_mode, trust_remote_code, custom_tokenizer):
global _worker_tokenizer
_worker_tokenizer = load_tokenizer(tokenizer_id, tokenizer_mode, trust_remote_code, custom_tokenizer)
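
(Design note: loading one tokenizer per worker in this initializer avoids pickling and shipping the parent's tokenizer with every mapped chunk; for work this small per prompt, that serialization cost would likely dominate.)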


def _process_random_no_chat(args):
"""Per-prompt body for the non-chat path — verbatim from the existing
serial loop in ``sample_random_requests``."""
i, offset, input_len, output_len, prefix_token_ids, prefix_len, vocab_size = args
tokenizer = _worker_tokenizer
prompt = tokenizer.decode(
prefix_token_ids + [(offset + i + j) % vocab_size for j in range(input_len)]
)
re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[: (prefix_len + input_len)]
prompt = tokenizer.decode(re_encoded_sequence)
return (prompt, int(prefix_len + input_len), int(output_len), None)


def _process_random_chat(args):
"""Per-prompt body for the chat-template path — verbatim from the
existing serial loop in ``sample_random_requests``."""
i, offset, input_len, output_len, chat_template_len, vocab_size = args
tokenizer = _worker_tokenizer
origin_text = tokenizer.decode(
[(offset + i + j) % vocab_size for j in range(int(input_len * 1.5))]
)
re_encoded_sequence = tokenizer.encode(origin_text, add_special_tokens=False)[: input_len]
prompt_text = tokenizer.decode(re_encoded_sequence)
prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt_text}],
add_generation_prompt=True,
tokenize=False,
)
return (prompt, int(input_len + chat_template_len), int(output_len), None)
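
A minimal in-process sketch of the worker contract (illustrative, not part of the PR; `tok` stands for any loaded PreTrainedTokenizerBase):

# What the serial fallback does before invoking a worker function:
_worker_tokenizer = tok
prompt, prompt_len, out_len, _ = _process_random_no_chat(
    # (i, offset, input_len, output_len, prefix_token_ids, prefix_len, vocab_size)
    (0, 1234, 128, 256, [], 0, tok.vocab_size)
)
assert prompt_len == 128  # always prefix_len + input_len, by construction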


def sample_random_requests(
prefix_len: int,
input_len: int,
@@ -364,6 +407,11 @@ def sample_random_requests(
range_ratio: float,
tokenizer: PreTrainedTokenizerBase,
use_chat_template: bool = False,
num_workers: int = 0,
tokenizer_id: str | None = None,
tokenizer_mode: str = "auto",
trust_remote_code: bool = False,
custom_tokenizer: str | None = None,
) -> list[tuple[str, int, int]]:
if use_chat_template:
chat_template_len = len(tokenizer.encode(
@@ -386,31 +434,49 @@
         size=num_prompts,
     )
     offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
-    input_requests = []
+    vocab_size = tokenizer.vocab_size
+
+    # Build (i, offset, input_len, output_len, ...) tuples to feed the
+    # per-prompt worker function. Serial and parallel paths run the same
+    # worker function so behavior is identical in both.
     if use_chat_template:
-        for i in range(num_prompts):
-            origin_text = tokenizer.decode(
-                [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(int(input_lens[i] * 1.5))]
-            )
-            re_encoded_sequence = tokenizer.encode(origin_text, add_special_tokens=False)[: input_lens[i]]
-            prompt_text = tokenizer.decode(re_encoded_sequence)
-            prompt = tokenizer.apply_chat_template(
-                [{"role": "user", "content": prompt_text}],
-                add_generation_prompt=True,
-                tokenize=False,
-            )
-            input_lens[i] += chat_template_len
-            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i]), None))
+        args_list = [
+            (i, int(offsets[i]), int(input_lens[i]), int(output_lens[i]),
+             chat_template_len, vocab_size)
+            for i in range(num_prompts)
+        ]
+        worker_fn = _process_random_chat
     else:
-        prefix_token_ids = np.random.randint(0, tokenizer.vocab_size, size=prefix_len).tolist()
-        for i in range(num_prompts):
-            prompt = tokenizer.decode(
-                prefix_token_ids + [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]
-            )
-            re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[: (prefix_len + input_lens[i])]
-            prompt = tokenizer.decode(re_encoded_sequence)
-            input_requests.append((prompt, int(prefix_len + input_lens[i]), int(output_lens[i]), None))
+        prefix_token_ids = np.random.randint(0, vocab_size, size=prefix_len).tolist()
+        args_list = [
+            (i, int(offsets[i]), int(input_lens[i]), int(output_lens[i]),
+             prefix_token_ids, prefix_len, vocab_size)
+            for i in range(num_prompts)
+        ]
+        worker_fn = _process_random_no_chat
+
+    # num_workers <= 0 means auto: cap at 8 (matches sglang/vllm bench defaults).
+    if num_workers <= 0:
+        num_workers = min(cpu_count() or 1, 8)
+    use_parallel = num_workers > 1 and tokenizer_id is not None
+
+    if use_parallel:
+        with Pool(
+            processes=num_workers,
+            initializer=_init_random_worker,
+            initargs=(tokenizer_id, tokenizer_mode, trust_remote_code, custom_tokenizer),
+        ) as pool:
+            input_requests = pool.map(
+                worker_fn, args_list,
+                chunksize=max(1, num_prompts // (num_workers * 4)),
+            )
+    else:
+        # Serial path: reuse the worker function with the parent-process
+        # tokenizer published into the module global so behavior matches
+        # the parallel path exactly.
+        global _worker_tokenizer
+        _worker_tokenizer = tokenizer
+        input_requests = [worker_fn(a) for a in args_list]
 
     return input_requests
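
A minimal usage sketch of the two paths (illustrative, not part of the PR: the model id is a placeholder, and the output_len/num_prompts parameter names are inferred from the function body because the diff folds that part of the signature):

import numpy as np
from transformers import AutoTokenizer

MODEL_ID = "gpt2"  # placeholder; any HF tokenizer id works
tok = AutoTokenizer.from_pretrained(MODEL_ID)
common = dict(prefix_len=0, input_len=1024, output_len=512,
              num_prompts=5000, range_ratio=0.5, tokenizer=tok)

np.random.seed(0)
serial = sample_random_requests(num_workers=1, **common)  # serial fallback

np.random.seed(0)  # all randomness is drawn in the parent process,
parallel = sample_random_requests(num_workers=8, tokenizer_id=MODEL_ID, **common)

# ...so both paths yield identical requests, assuming the workers'
# load_tokenizer(MODEL_ID, ...) resolves to the same tokenizer as above.
assert serial == parallel

The chunksize of max(1, num_prompts // (num_workers * 4)) hands each worker roughly four chunks, enough to amortize inter-process overhead while still evening out stragglers.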

@@ -996,6 +1062,11 @@ def main(args: argparse.Namespace):
range_ratio=args.random_range_ratio,
tokenizer=tokenizer,
use_chat_template=args.use_chat_template,
num_workers=args.random_num_workers,
tokenizer_id=tokenizer_id,
tokenizer_mode=args.tokenizer_mode,
trust_remote_code=args.trust_remote_code,
custom_tokenizer=args.custom_tokenizer,
)

else:
@@ -1331,6 +1402,14 @@ def main(args: argparse.Namespace):
action="store_true",
help="Use chat template to format the prompt.",
)
random_group.add_argument(
"--random-num-workers",
type=int,
default=0,
help="Number of worker processes for parallel random prompt "
"generation. Only used with --dataset-name random. "
"0 (default) = auto (min(cpu_count, 8)). 1 = serial.",
)
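
For a direct run (abridged; model and endpoint flags elided), the new flag composes with the existing random-dataset options:

python benchmark_serving.py \
    --dataset-name random \
    --random-input-len 1024 --random-output-len 512 \
    --random-range-ratio 0.5 \
    --random-num-workers 8  # 0 = auto (min(cpu_count, 8)), 1 = serial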

hf_group = parser.add_argument_group("hf dataset options")
hf_group.add_argument("--hf-subset", type=str, default=None, help="Subset of the HF dataset.")