Skip to content

Commit b9419dd

Browse files
committed
Add training e2e eval flows
1 parent 1f8776e commit b9419dd

7 files changed

Lines changed: 646 additions & 9 deletions

File tree

Makefile

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,31 @@ HOST ?= 127.0.0.1
1313
PORT ?= 9003
1414
# The fully qualified base URL used by local CLI tools and clients
1515
BASE_URL ?= http://$(HOST):$(PORT)
16-
TEST_PYTHONPATH ?= examples/sft/pig-latin
16+
# Dotted unittest module names handed to `python -m unittest` by the unit mode
# of the `test` target.
UNIT_TESTS ?= tests.test_gateway_paths tests.test_snapshot_agent tests.test_trainer_optimizer_correctness tests.test_worker_launch_processor
17+
# Only forward BASE_URL to e2e when the user supplied it. The Makefile default
18+
# is for local CLI usage; e2e should start its own backend by default.
19+
TRAINING_TEST_BASE_URL ?= $(if $(filter environment command line,$(origin BASE_URL)),$(BASE_URL),)
20+
# uv extra used for the e2e run (`uv run --extra ...`); also forwarded to the
# e2e script as backend.uv_extra.
TRAINING_TEST_EXTRA ?= gpu
21+
# Extra arguments appended, unquoted, to the run_training_e2e.py command line.
TRAINING_TEST_ARGS ?=
22+
# PYTHONPATH used by the piglatin mode of the `test` target.
PIGLATIN_TEST_PYTHONPATH ?= examples/sft/pig-latin
23+
24+
# CUDA_VISIBLE_DEVICES can be provided either as an environment variable or as a
25+
# Make variable, and is inherited by the backend/eval subprocesses.
26+
ifneq ($(origin CUDA_VISIBLE_DEVICES),undefined)
27+
export CUDA_VISIBLE_DEVICES
28+
endif
1729

1830
# Print usage examples for the server, vllm, test (unit / e2e / piglatin) and
# lint/fmt targets. Output only; nothing is built.
help:
1931
@echo "make server # $(BASE_MODEL), SAMPLING_BACKEND=$(SAMPLING_BACKEND), port $(PORT)"
2032
@echo "make server BASE_MODEL=google/gemma-4-e2b SAMPLING_BACKEND=vllm"
2133
@echo "VLLM_ARCHITECTURE_OVERRIDE=Gemma4ForCausalLM make vllm BASE_MODEL=google/gemma-4-e2b"
22-
@echo "make test | lint | fmt"
34+
@echo "make test # fast unit tests"
35+
@echo "make test e2e lora-sft|lora-rl|fft-sft # starts a local backend, then runs existing training/eval examples"
36+
@echo "make test e2e lora-sft BASE_URL=http://host:9003"
37+
@echo "CUDA_VISIBLE_DEVICES=0 make test e2e fft-sft"
38+
@echo "make test e2e fft-sft TRAINING_TEST_ARGS='gsm8k.steps=10'"
39+
@echo "make test piglatin # pig-latin example end-to-end tests"
40+
@echo "make lint | fmt"
2341

2442
# ---------------------------------------------------------------------------
2543
# Server
@@ -42,14 +60,40 @@ ifeq (cli,$(firstword $(MAKECMDGOALS)))
4260
$(eval $(CLI_ARGS):;@:)
4361
endif
4462

63+
# Positional arguments for `make test <mode> [<scenario>]`.
# When `test` is the first goal on the command line, the second and third goals
# are captured as TEST_MODE and TEST_SCENARIO. Every goal after `test` is then
# declared as a target with a no-op recipe (`:;@:`), so make does not fail
# trying to build words such as `e2e` or `lora-sft` as real targets.
ifeq (test,$(firstword $(MAKECMDGOALS)))
64+
TEST_MODE := $(word 2,$(MAKECMDGOALS))
65+
TEST_SCENARIO := $(word 3,$(MAKECMDGOALS))
66+
TEST_ARGS := $(wordlist 2,$(words $(MAKECMDGOALS)),$(MAKECMDGOALS))
67+
ifneq ($(TEST_ARGS),)
68+
$(eval $(TEST_ARGS):;@:)
69+
endif
70+
endif
71+
4572
# Run dev/tools/cli.py through uv with BASE_URL set in its environment,
# forwarding $(CLI_ARGS) as the script's arguments.
cli:
4673
@cd dev/tools && BASE_URL="$(BASE_URL)" uv run python cli.py $(CLI_ARGS)
4774

4875
# ---------------------------------------------------------------------------
4976
# Dev
5077
# ---------------------------------------------------------------------------
5178
# Test entry point; dispatches on $(TEST_MODE) (the goal following `test`):
#   (empty) | unit  run `python -m unittest $(UNIT_TESTS)` via uv with the cpu
#                   extra and --frozen.
#   e2e             requires $(TEST_SCENARIO); runs scripts/run_training_e2e.py
#                   with scenario=<scenario> and backend.uv_extra=<extra>, adds
#                   backend.base_url=<url> only when TRAINING_TEST_BASE_URL is
#                   non-empty, then appends $(TRAINING_TEST_ARGS).
#   piglatin        run unittest discovery in `tests` under the examples uv
#                   project with PYTHONPATH=$(PIGLATIN_TEST_PYTHONPATH).
# A missing e2e scenario or an unknown mode prints a message and exits with
# status 2. The whole recipe is one backslash-continued shell command, so the
# shell variables and positional parameters (`set --`) persist across lines.
# NOTE(review): the line under the `52-` marker appears to be the previous
# recipe removed by this commit, shown here only as diff context.
test:
52-
PYTHONPATH="$(TEST_PYTHONPATH)" uv --project examples run python -m unittest discover -s tests
79+
@mode="$(TEST_MODE)"; \
80+
scenario="$(TEST_SCENARIO)"; \
81+
if [ -z "$$mode" ] || [ "$$mode" = "unit" ]; then \
82+
uv run --frozen --extra cpu python -m unittest $(UNIT_TESTS); \
83+
elif [ "$$mode" = "e2e" ]; then \
84+
if [ -z "$$scenario" ]; then \
85+
echo "Missing e2e scenario. Expected lora-sft, lora-rl, or fft-sft."; \
86+
exit 2; \
87+
fi; \
88+
set -- "scenario=$$scenario" "backend.uv_extra=$(TRAINING_TEST_EXTRA)"; \
89+
if [ -n "$(TRAINING_TEST_BASE_URL)" ]; then set -- "$$@" "backend.base_url=$(TRAINING_TEST_BASE_URL)"; fi; \
90+
uv run --extra "$(TRAINING_TEST_EXTRA)" python scripts/run_training_e2e.py "$$@" $(TRAINING_TEST_ARGS); \
91+
elif [ "$$mode" = "piglatin" ]; then \
92+
PYTHONPATH="$(PIGLATIN_TEST_PYTHONPATH)" uv --project examples run python -m unittest discover -s tests; \
93+
else \
94+
echo "Unknown test mode '$$mode'. Expected unit, e2e, or piglatin."; \
95+
exit 2; \
96+
fi
5397

5498
lint:
5599
uvx ruff check .

examples/sft/gsm8k/vllm_eval.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,19 @@ def main() -> None:
2222
parser = argparse.ArgumentParser()
2323
parser.add_argument("--path", required=True)
2424
parser.add_argument("--data", default="gsm8k_test.json")
25+
parser.add_argument("--gpu-memory-utilization", type=float, default=0.85)
2526
args = parser.parse_args()
2627

2728
with open(args.data) as f:
2829
data = json.load(f)
2930

30-
llm = LLM(model=args.path, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=1024, enforce_eager=True)
31+
llm = LLM(
32+
model=args.path,
33+
dtype="bfloat16",
34+
gpu_memory_utilization=args.gpu_memory_utilization,
35+
max_model_len=1024,
36+
enforce_eager=True,
37+
)
3138
sampling_params = SamplingParams(temperature=0.0, max_tokens=256, stop=["\nQuestion:"])
3239
start = time.time()
3340
outputs = llm.generate([datum["prompt"] for datum in data], sampling_params)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ description = "Open-RL server and training runtime."
99
readme = "README.md"
1010
requires-python = ">=3.12, <3.13"
1111
dependencies = [
12+
"chz>=0.4.0",
1213
"fastapi",
1314
"opentelemetry-api",
1415
"opentelemetry-sdk",

0 commit comments

Comments
 (0)