Skip to content

Commit 2bfb04a

Browse files
committed
feat(examples,scripts): add micro-batched single invocation multi-model evaluation and interleaved orchestration
1 parent ef2a276 commit 2bfb04a

3 files changed

Lines changed: 49 additions & 28 deletions

File tree

examples/sft/gsm8k/gsm8k_sft.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,8 +40,8 @@ def make_datum(row: dict) -> tinker.Datum:
4040

4141
eval_dataset = load_dataset("openai/gsm8k", "main", split="test[:16]")
4242
return (
43-
SupervisedDatasetFromHFDataset(dataset, self.batch_size, map_fn=make_datum),
44-
SupervisedDatasetFromHFDataset(eval_dataset, self.batch_size, map_fn=make_datum),
43+
SupervisedDatasetFromHFDataset(dataset, self.batch_size, map_fn=make_datum),
44+
SupervisedDatasetFromHFDataset(eval_dataset, self.batch_size, map_fn=make_datum),
4545
)
4646

4747

examples/sft/gsm8k/vllm_eval.py

Lines changed: 36 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -22,43 +22,59 @@ def main() -> None:
2222
from tinker_cookbook.tokenizer_utils import get_tokenizer
2323

2424
parser = argparse.ArgumentParser()
25-
parser.add_argument("--path", required=True)
25+
parser.add_argument("--path", required=True, action="append", help="One or more URI paths to evaluate concurrently")
2626
parser.add_argument("--base-model", default="Qwen/Qwen2.5-0.5B")
2727
parser.add_argument("--base-url", default=os.getenv("TINKER_BASE_URL", os.getenv("BASE_URL", "http://127.0.0.1:8000")))
2828
parser.add_argument("--data", default="gsm8k_test.json")
2929
parser.add_argument("--gpu-memory-utilization", type=float, default=0.85)
30+
parser.add_argument("--microbatch-size", type=int, default=10, help="Number of evaluation problems to dispatch per micro-batch")
3031
parser.add_argument("--min-accuracy", type=float, default=0.0, help="exit nonzero if accuracy falls below this fraction")
3132
args = parser.parse_args()
3233

3334
with open(args.data) as f:
3435
data = json.load(f)
3536

37+
paths = args.path if isinstance(args.path, list) else [args.path]
3638
client = ServiceClient(api_key=os.getenv("TINKER_API_KEY", "tml-dummy-key"), base_url=args.base_url)
37-
sampler = client.create_sampling_client(args.path)
39+
samplers = [client.create_sampling_client(p) for p in paths]
3840
tokenizer = get_tokenizer(args.base_model)
3941

4042
sampling_params = types.SamplingParams(temperature=0.0, max_tokens=256)
4143
start = time.time()
42-
43-
outputs = []
44-
for datum in data:
45-
prompt_tokens = tokenizer.encode(datum["prompt"], add_special_tokens=False)
46-
seqs = sampler.sample(
47-
prompt=types.ModelInput.from_ints(tokens=prompt_tokens),
48-
num_samples=1,
49-
sampling_params=sampling_params,
50-
).result().sequences
51-
outputs.append(tokenizer.decode(seqs[0].tokens) if seqs else "")
5244

53-
elapsed = time.time() - start
54-
correct = sum(int(extract(text) == datum["gold"]) for datum, text in zip(data, outputs, strict=True))
55-
accuracy = correct / len(data)
45+
import asyncio
46+
47+
async def run_evals():
48+
outputs_by_sampler = [[] for _ in paths]
49+
batch_size = args.microbatch_size
50+
for i in range(0, len(data), batch_size):
51+
chunk = data[i : i + batch_size]
52+
for s_idx, sampler in enumerate(samplers):
53+
tasks = [
54+
sampler.sample_async(
55+
prompt=types.ModelInput.from_ints(tokens=tokenizer.encode(datum["prompt"], add_special_tokens=False)),
56+
num_samples=1,
57+
sampling_params=sampling_params,
58+
)
59+
for datum in chunk
60+
]
61+
res_list = await asyncio.gather(*tasks)
62+
for res in res_list:
63+
seqs = res.sequences
64+
outputs_by_sampler[s_idx].append(tokenizer.decode(seqs[0].tokens) if seqs else "")
65+
return outputs_by_sampler
5666

57-
print("***************************************************************")
58-
print(f"[SAMPLER] {args.path} 0-shot GSM8K acc = {accuracy:.1%} on {len(data)} problems in {elapsed:.1f}s")
59-
print("***************************************************************")
60-
if accuracy < args.min_accuracy:
61-
raise SystemExit(f"GSM8K accuracy {accuracy:.1%} is below the required {args.min_accuracy:.1%}")
67+
outputs_by_sampler = asyncio.run(run_evals())
68+
69+
elapsed = time.time() - start
70+
for path, outputs in zip(paths, outputs_by_sampler, strict=True):
71+
correct = sum(int(extract(text) == datum["gold"]) for datum, text in zip(data, outputs, strict=True))
72+
accuracy = correct / len(data)
73+
print("***************************************************************")
74+
print(f"[SAMPLER] {path} 0-shot GSM8K acc = {accuracy:.1%} on {len(data)} problems in {elapsed:.1f}s")
75+
print("***************************************************************")
76+
if accuracy < args.min_accuracy:
77+
raise SystemExit(f"GSM8K accuracy {accuracy:.1%} for {path} is below the required {args.min_accuracy:.1%}")
6278

6379

6480
if __name__ == "__main__":

scripts/run_training_e2e.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -396,12 +396,15 @@ def run_gsm8k_train(config: RunConfig, base_url: str, watch: list[ManagedProcess
396396
return run_example(config, ["examples/sft/gsm8k/gsm8k_sft.py"], defaults, watch=watch, prefix=prefix)
397397

398398

399-
def run_gsm8k_eval(config: RunConfig, model_path: str) -> None:
399+
def run_gsm8k_eval(config: RunConfig, model_path: str | list[str]) -> None:
400+
paths = model_path if isinstance(model_path, list) else [model_path]
401+
path_args = []
402+
for p in paths:
403+
path_args.extend(["--path", p])
400404
run_command(
401405
["uv", "--project", "examples", "run", "python", "examples/sft/gsm8k/vllm_eval.py"]
406+
+ path_args
402407
+ [
403-
"--path",
404-
model_path,
405408
"--base-url",
406409
config.base_url or "http://127.0.0.1:8000",
407410
"--data",
@@ -472,10 +475,12 @@ def train(job: str) -> None:
472475
raise RuntimeError(f"gsm8k {job} failed") from result
473476

474477
check_snapshot_interleaving(config)
475-
for job, result in sorted(results.items()):
478+
eval_paths = []
479+
for _, result in sorted(results.items()):
476480
assert isinstance(result, str)
477-
print(f"[training-e2e] evaluating {job}")
478-
run_gsm8k_eval(config, resolve_eval_model_path(result))
481+
eval_paths.append(resolve_eval_model_path(result))
482+
print(f"[training-e2e] evaluating jobs in single micro-batched invocation: {eval_paths}")
483+
run_gsm8k_eval(config, eval_paths)
479484

480485

481486
def run_tiny_fft_rl_x2(config: RunConfig, base_url: str, watch: list[ManagedProcess]) -> None:

0 commit comments

Comments
 (0)