Skip to content

Commit d896832

Browse files
committed
Add training e2e checks
1 parent 74e83f7 commit d896832

15 files changed

Lines changed: 1017 additions & 32 deletions

File tree

.dockerignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
.git
2-
.venv
3-
examples/.venv
2+
**/.venv
43
**/__pycache__
54
**/*.pyc
65
.ruff_cache

Makefile

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,32 @@ HOST ?= 127.0.0.1
1313
PORT ?= 9003
1414
# The fully qualified base URL used by local CLI tools and clients
1515
BASE_URL ?= http://$(HOST):$(PORT)
16-
TEST_PYTHONPATH ?= examples/sft/pig-latin
16+
UNIT_TESTS ?= tests.test_gateway_paths tests.test_lora_targets tests.test_snapshot_agent tests.test_trainer_optimizer_correctness tests.test_worker_launch_processor
17+
# Only forward BASE_URL to e2e when the user supplied it. The Makefile default
18+
# is for local CLI usage; e2e should start its own backend by default.
19+
TRAINING_TEST_BASE_URL ?= $(if $(filter environment command line,$(origin BASE_URL)),$(BASE_URL),)
20+
TRAINING_TEST_EXTRA ?= gpu
21+
TRAINING_TEST_ARGS ?=
22+
PIGLATIN_TEST_PYTHONPATH ?= examples/sft/pig-latin
23+
24+
# CUDA_VISIBLE_DEVICES can be provided either as an environment variable or as a
25+
# Make variable, and is inherited by the backend/eval subprocesses.
26+
ifneq ($(origin CUDA_VISIBLE_DEVICES),undefined)
27+
export CUDA_VISIBLE_DEVICES
28+
endif
1729

1830
help:
1931
@echo "make server # $(BASE_MODEL), SAMPLING_BACKEND=$(SAMPLING_BACKEND), port $(PORT)"
2032
@echo "make server BASE_MODEL=google/gemma-4-e2b SAMPLING_BACKEND=vllm"
2133
@echo "VLLM_ARCHITECTURE_OVERRIDE=Gemma4ForCausalLM make vllm BASE_MODEL=google/gemma-4-e2b"
22-
@echo "make test | lint | fmt"
34+
@echo "make test # fast unit tests"
35+
@echo "make test e2e tiny-lora|tiny-fft|tiny-rl|lora-textsql|fft-gsm8k|fft-gsm8k-x2 # tiny-* = fast overfit smoke tests"
36+
@echo "make test e2e tiny-lora BASE_URL=http://host:9003"
37+
@echo "CUDA_VISIBLE_DEVICES=0 make test e2e tiny-fft"
38+
@echo "make test e2e tiny-fft TRAINING_TEST_ARGS='steps=20'"
39+
@echo "make test e2e fft-gsm8k TRAINING_TEST_ARGS='steps=10 eval_examples=8 extra=\"batch=2\"'"
40+
@echo "make test piglatin # pig-latin example end-to-end tests"
41+
@echo "make lint | fmt"
2342

2443
# ---------------------------------------------------------------------------
2544
# Server
@@ -42,14 +61,40 @@ ifeq (cli,$(firstword $(MAKECMDGOALS)))
4261
$(eval $(CLI_ARGS):;@:)
4362
endif
4463

64+
# `make test <mode> [scenario]`: when `test` is the first goal, every goal that
# follows it is treated as an argument to the test recipe rather than a target.
ifeq ($(firstword $(MAKECMDGOALS)),test)
  TEST_ARGS := $(wordlist 2,$(words $(MAKECMDGOALS)),$(MAKECMDGOALS))
  TEST_MODE := $(firstword $(TEST_ARGS))
  TEST_SCENARIO := $(word 2,$(TEST_ARGS))
  ifneq ($(strip $(TEST_ARGS)),)
    # Give each argument a no-op rule so Make does not complain that it has no
    # rule to build goals such as `e2e` or `tiny-lora`.
    $(eval $(TEST_ARGS):;@:)
  endif
endif
72+
4573
cli:
4674
@cd dev/tools && BASE_URL="$(BASE_URL)" uv run python cli.py $(CLI_ARGS)
4775

4876
# ---------------------------------------------------------------------------
4977
# Dev
5078
# ---------------------------------------------------------------------------
5179
# Test entry point. The mode and scenario are taken from the goals that follow
# `test` on the command line (TEST_MODE / TEST_SCENARIO):
#   make test, make test unit   run the modules listed in UNIT_TESTS
#   make test e2e <scenario>    run scripts/run_training_e2e.py for one scenario
#   make test piglatin          discover tests with the pig-latin example on PYTHONPATH
# A missing e2e scenario or an unknown mode exits with status 2.
test:
	@mode="$(TEST_MODE)"; \
	scenario="$(TEST_SCENARIO)"; \
	case "$$mode" in \
	""|unit) \
		uv run --frozen --exact --extra cpu python -m unittest $(UNIT_TESTS) ;; \
	e2e) \
		if [ -z "$$scenario" ]; then \
			echo "Missing e2e scenario. Expected tiny-lora, tiny-fft, tiny-rl, lora-textsql, fft-gsm8k, or fft-gsm8k-x2."; \
			exit 2; \
		fi; \
		set -- "scenario=$$scenario" "uv_extra=$(TRAINING_TEST_EXTRA)"; \
		if [ -n "$(TRAINING_TEST_BASE_URL)" ]; then set -- "$$@" "base_url=$(TRAINING_TEST_BASE_URL)"; fi; \
		uv run --extra "$(TRAINING_TEST_EXTRA)" python scripts/run_training_e2e.py "$$@" $(TRAINING_TEST_ARGS) ;; \
	piglatin) \
		PYTHONPATH="$(PIGLATIN_TEST_PYTHONPATH)" uv --project examples run python -m unittest discover -s tests ;; \
	*) \
		echo "Unknown test mode '$$mode'. Expected unit, e2e, or piglatin."; \
		exit 2 ;; \
	esac
5398

5499
lint:
55100
uvx ruff check .

examples/sft/gsm8k/vllm_eval.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,32 @@ def main() -> None:
2222
parser = argparse.ArgumentParser()
2323
parser.add_argument("--path", required=True)
2424
parser.add_argument("--data", default="gsm8k_test.json")
25+
parser.add_argument("--gpu-memory-utilization", type=float, default=0.85)
26+
parser.add_argument("--min-accuracy", type=float, default=0.0, help="exit nonzero if accuracy falls below this fraction")
2527
args = parser.parse_args()
2628

2729
with open(args.data) as f:
2830
data = json.load(f)
2931

30-
llm = LLM(model=args.path, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=1024, enforce_eager=True)
32+
llm = LLM(
33+
model=args.path,
34+
dtype="bfloat16",
35+
gpu_memory_utilization=args.gpu_memory_utilization,
36+
max_model_len=1024,
37+
enforce_eager=True,
38+
)
3139
sampling_params = SamplingParams(temperature=0.0, max_tokens=256, stop=["\nQuestion:"])
3240
start = time.time()
3341
outputs = llm.generate([datum["prompt"] for datum in data], sampling_params)
3442
elapsed = time.time() - start
3543
correct = sum(int(extract(output.outputs[0].text) == datum["gold"]) for datum, output in zip(data, outputs, strict=True))
44+
accuracy = correct / len(data)
3645

3746
print("***************************************************************")
38-
print(f"[VLLM] {args.path} 0-shot GSM8K acc = {correct / len(data):.1%} on {len(data)} problems in {elapsed:.1f}s")
47+
print(f"[VLLM] {args.path} 0-shot GSM8K acc = {accuracy:.1%} on {len(data)} problems in {elapsed:.1f}s")
3948
print("***************************************************************")
49+
if accuracy < args.min_accuracy:
50+
raise SystemExit(f"GSM8K accuracy {accuracy:.1%} is below the required {args.min_accuracy:.1%}")
4051

4152

4253
if __name__ == "__main__":

examples/tiny/tiny_rl.py

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
"""Tiny RL smoke test: sample from the current policy, reward completions that
2+
contain the target answer, and run a few importance-sampling policy-gradient steps.
3+
4+
uv --project examples run python examples/tiny/tiny_rl.py base_url=http://127.0.0.1:9003
5+
"""
6+
7+
from __future__ import annotations
8+
9+
import json
10+
import math
11+
import os
12+
import shutil
13+
import statistics
14+
from pathlib import Path
15+
from typing import Any, cast
16+
17+
import chz
18+
import tinker
19+
from tinker import types
20+
21+
BASE_URL = "http://127.0.0.1:9003"
22+
23+
os.environ.setdefault("TINKER_API_KEY", "tml-dummy-key")
24+
os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
25+
26+
27+
@chz.chz
class Config:
    # Base model that the LoRA training client is created for.
    base_model: str = "Qwen/Qwen2.5-0.5B"
    # Resolution order: TINKER_BASE_URL, then BASE_URL, then the module default.
    base_url: str = os.getenv("TINKER_BASE_URL", os.getenv("BASE_URL", BASE_URL))
    # Defaults to artifacts/tiny_rl next to this script; metrics.jsonl is written here.
    log_dir: str = str(Path(__file__).parent / "artifacts" / "tiny_rl")
    # Single prompt sampled every step; a completion is rewarded when its decoded
    # text contains `target`.
    prompt: str = "Question: What is 2 + 2?\nAnswer:"
    target: str = "4"
    # Number of sample -> forward/backward -> optimizer-step iterations (must be >= 1).
    steps: int = 2
    # Completions requested per step.
    samples_per_prompt: int = 8
    # Sampling parameters.
    max_tokens: int = 16
    temperature: float = 1.0
    # Adam parameters passed to every optimizer step.
    learning_rate: float = 1e-5
    grad_clip_norm: float = 1.0
    # Loss name forwarded to forward_backward.
    loss_fn: str = "importance_sampling"
    # LoRA rank and seed used when creating the training client.
    rank: int = 16
    seed: int = 0
    # "delete" wipes an existing log_dir, "error" refuses to reuse it.
    behavior_if_log_dir_exists: str = "delete"
44+
45+
46+
def reset_log_dir(path: Path, behavior: str) -> None:
    """Make sure ``path`` exists as a fresh, empty log directory.

    Args:
        path: Directory that will hold this run's logs. Missing parents are created.
        behavior: Policy for an already existing ``path``. ``"delete"`` removes it
            recursively and recreates it empty; ``"error"`` refuses to reuse it.

    Raises:
        ValueError: ``behavior`` is neither ``"delete"`` nor ``"error"``.
        RuntimeError: ``behavior`` is ``"error"`` and ``path`` already exists.
    """
    # Validate before touching the filesystem so that a misspelled policy fails on
    # the very first run, not only once the directory happens to exist.
    if behavior not in ("delete", "error"):
        raise ValueError(f"Unsupported behavior_if_log_dir_exists={behavior!r}")

    if path.exists():
        if behavior == "error":
            raise RuntimeError(f"Log directory already exists: {path}")
        shutil.rmtree(path)
    path.mkdir(parents=True)
57+
58+
59+
def write_metric(log_dir: Path, row: dict[str, Any]) -> None:
    """Append ``row`` to ``metrics.jsonl`` in ``log_dir`` as one key-sorted JSON line."""
    metrics_file = log_dir / "metrics.jsonl"
    with open(metrics_file, mode="a", encoding="utf-8") as sink:
        serialized = json.dumps(row, sort_keys=True)
        sink.write(f"{serialized}\n")
62+
63+
64+
def build_datum(prompt_tokens: list[int], completion_tokens: list[int], logprobs: list[float], advantage: float) -> types.Datum:
    """Pack one sampled completion into a next-token-prediction training datum.

    The prompt and completion are concatenated; the model input is every token but
    the last and ``target_tokens`` is every token but the first. The per-position
    arrays (``weights``, ``logprobs``, ``advantages``) are zero wherever the target
    is still a prompt token, so only completion tokens carry weight.

    Args:
        prompt_tokens: Token ids of the prompt; must not be empty.
        completion_tokens: Token ids sampled after the prompt.
        logprobs: One sampling log-probability per completion token.
        advantage: Scalar advantage applied to every completion token.

    Returns:
        A ``types.Datum`` whose ``loss_fn_inputs`` arrays all have the same length
        as ``target_tokens``.

    Raises:
        ValueError: ``prompt_tokens`` is empty, or ``logprobs`` and
            ``completion_tokens`` differ in length. Either case would produce
            per-position arrays that do not line up with ``target_tokens``.
    """
    if not prompt_tokens:
        # With no prompt, `[0.0] * -1` collapses to [] and every per-position array
        # ends up one element longer than target_tokens.
        raise ValueError("build_datum needs at least one prompt token")
    if len(logprobs) != len(completion_tokens):
        raise ValueError(f"Expected one logprob per completion token, got {len(completion_tokens)} tokens and {len(logprobs)} logprobs")

    tokens = prompt_tokens + completion_tokens
    # The target at position i is tokens[i + 1], so the first completion token is
    # the target at index len(prompt_tokens) - 1; everything before it is prompt.
    prompt_pad = [0.0] * (len(prompt_tokens) - 1)
    return types.Datum(
        model_input=types.ModelInput.from_ints(tokens=tokens[:-1]),
        loss_fn_inputs=cast(
            Any,
            {
                "target_tokens": tokens[1:],
                "weights": prompt_pad + [1.0] * len(completion_tokens),
                "logprobs": prompt_pad + logprobs,
                "advantages": prompt_pad + [advantage] * len(completion_tokens),
            },
        ),
    )
79+
80+
81+
def main(config: Config) -> None:
    """Run the tiny RL loop: sample, score, and take one policy-gradient step per iteration.

    Each step saves the current weights, samples ``samples_per_prompt`` completions
    of the configured prompt, gives reward 1.0 to completions whose decoded text
    contains ``config.target``, and trains on mean-centered advantages. A metrics
    row is appended to ``metrics.jsonl`` per step, and the final training state is
    saved and its path printed.

    Raises:
        ValueError: ``config.steps`` is below 1.
        RuntimeError: the sampler returns a sequence with no tokens or with
            mismatched token/logprob counts, or the reported loss is not finite.
    """
    if config.steps < 1:
        raise ValueError("Tiny RL needs steps >= 1")
    log_dir = Path(config.log_dir)
    reset_log_dir(log_dir, config.behavior_if_log_dir_exists)

    api_key = os.getenv("TINKER_API_KEY", "tml-dummy-key")
    service = tinker.ServiceClient(api_key=api_key, base_url=config.base_url)
    trainer = service.create_lora_training_client(
        base_model=config.base_model,
        rank=config.rank,
        seed=config.seed,
        train_attn=True,
        train_mlp=True,
        # Qwen2.5-0.5B ties lm_head to embed_tokens; LoRA on the tied head trips a
        # PEFT warning and vLLM cannot load lm_head adapter weights at all.
        train_unembed=False,
    )
    tokenizer = trainer.get_tokenizer()
    prompt_tokens = tokenizer.encode(config.prompt, add_special_tokens=False)
    prompt = types.ModelInput.from_ints(tokens=prompt_tokens)
    sampling_params = types.SamplingParams(max_tokens=config.max_tokens, temperature=config.temperature)

    mean_reward = 0.0
    for step in range(1, config.steps + 1):
        # Sample from the policy as it stands at the start of this step.
        sampler = trainer.save_weights_and_get_sampling_client()
        sample_future = sampler.sample(prompt=prompt, num_samples=config.samples_per_prompt, sampling_params=sampling_params)
        sequences = sample_future.result().sequences

        # Validate every sequence and score it before building any datum.
        rollouts = []
        rewards = []
        for sequence in sequences:
            tokens = list(sequence.tokens)
            logprobs = list(sequence.logprobs or [])
            if not tokens or len(tokens) != len(logprobs):
                raise RuntimeError(f"Sampler must return aligned tokens and logprobs, got {len(tokens)} tokens and {len(logprobs)} logprobs")
            rollouts.append((tokens, logprobs))
            hit = config.target in tokenizer.decode(tokens)
            rewards.append(1.0 if hit else 0.0)

        # Center rewards on the group mean. If every reward is identical the
        # centered advantages are all zero, so use a uniform positive advantage to
        # keep the gradient nonzero.
        mean_reward = statistics.fmean(rewards)
        advantages = [reward - mean_reward for reward in rewards]
        if all(abs(advantage) < 1e-8 for advantage in advantages):
            advantages = [1.0] * len(rewards)

        datums = []
        for (tokens, logprobs), advantage in zip(rollouts, advantages):
            datums.append(build_datum(prompt_tokens, tokens, logprobs, advantage))

        fwdbwd = trainer.forward_backward(datums, config.loss_fn).result()
        adam = types.AdamParams(learning_rate=config.learning_rate, grad_clip_norm=config.grad_clip_norm)
        trainer.optim_step(adam).result()

        loss = float(fwdbwd.metrics.get("loss:mean", 0.0))
        if not math.isfinite(loss):
            raise RuntimeError(f"Loss must be finite, got {loss!r}")
        write_metric(log_dir, {"phase": "train", "step": step, "loss": loss, "mean_reward": mean_reward, "num_datums": len(datums)})
        print(f"[tiny-rl] step={step:02d}/{config.steps} loss={loss:.6f} mean_reward={mean_reward:.2f} datums={len(datums)}")

    final_state_path = trainer.save_state("tiny-rl-final").result().path
    write_metric(log_dir, {"phase": "final", "step": config.steps, "final_state_path": final_state_path, "mean_reward": mean_reward})
    print(f"[tiny-rl] mean_reward={mean_reward:.2f}")
    print(f"final_state_path={final_state_path}")
139+
140+
141+
if __name__ == "__main__":
142+
chz.nested_entrypoint(main, allow_hyphens=True)

0 commit comments

Comments
 (0)