Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ env/
# Training Logs & Plots
scratch/
docs/scratch/
VERSION
*.txt
*.png
*.log
Expand Down
7 changes: 6 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ test:
fi; \
set -- "scenario=$$scenario" "uv_extra=$(TRAINING_TEST_EXTRA)"; \
if [ -n "$(TRAINING_TEST_BASE_URL)" ]; then set -- "$$@" "base_url=$(TRAINING_TEST_BASE_URL)"; fi; \
kubectl delete pods -l snapshot-agent=true --force --grace-period=0 2>/dev/null || true; \
uv run --extra "$(TRAINING_TEST_EXTRA)" python scripts/run_training_e2e.py "$$@" $(TRAINING_TEST_ARGS); \
elif [ "$$mode" = "piglatin" ]; then \
PYTHONPATH="$(PIGLATIN_TEST_PYTHONPATH)" uv --project examples run python -m unittest tests.test_piglatin_qwen tests.test_piglatin_gemma; \
Expand All @@ -113,7 +114,7 @@ fmt:
# Deployment (GKE)
# ---------------------------------------------------------------------------
GCP_PROJECT ?= cdrollouts-sunilarora
IMAGE_TAG ?= latest
# Default image tag: the short git commit hash, else the contents of the
# VERSION file (written by `push-vm`), else "latest".
# Resolved once at parse time (simple `:=` assignment) instead of re-running the
# shell on every expansion. An empty result (for example an empty VERSION file)
# also falls back to "latest" so image references never end in a bare ":".
# The `origin` guard keeps `?=` semantics: a value supplied on the command line
# or through the environment still wins.
ifeq ($(origin IMAGE_TAG),undefined)
IMAGE_TAG := $(shell tag=$$(git rev-parse --short HEAD 2>/dev/null || cat VERSION 2>/dev/null); echo "$${tag:-latest}")
endif

build-images:
DOCKER_BUILDKIT=1 docker build -t gcr.io/$(GCP_PROJECT)/open-rl-server:$(IMAGE_TAG) -f src/server/Dockerfile .
Expand All @@ -122,6 +123,9 @@ build-images:
# Push the server and gateway images tagged $(IMAGE_TAG) to gcr.io/$(GCP_PROJECT),
# then try to point the cluster at that tag:
#   - container `gateway` of deployment/open-rl-gateway -> the gateway image
#   - container `snapshot-agent` of daemonset/open-rl-snapshot-agent -> the server image
#   - env OPEN_RL_WORKER_IMAGE on deployment/open-rl-gateway -> the server image
# The three kubectl steps are best-effort: stderr is discarded and a non-zero
# exit is ignored (`2>/dev/null || true`), so this target still succeeds when
# kubectl fails. The two `docker push` steps are not guarded and do fail the target.
push-images:
docker push gcr.io/$(GCP_PROJECT)/open-rl-server:$(IMAGE_TAG)
docker push gcr.io/$(GCP_PROJECT)/open-rl-gateway:$(IMAGE_TAG)
kubectl set image deployment/open-rl-gateway gateway=gcr.io/$(GCP_PROJECT)/open-rl-gateway:$(IMAGE_TAG) 2>/dev/null || true
kubectl set image daemonset/open-rl-snapshot-agent snapshot-agent=gcr.io/$(GCP_PROJECT)/open-rl-server:$(IMAGE_TAG) 2>/dev/null || true
kubectl set env deployment/open-rl-gateway OPEN_RL_WORKER_IMAGE=gcr.io/$(GCP_PROJECT)/open-rl-server:$(IMAGE_TAG) 2>/dev/null || true

deploy:
kubectl apply -k k8s/deploy/distributed-lustre/
Expand Down Expand Up @@ -169,6 +173,7 @@ REMOTE_HOST ?= <PLACE_HOLDER_FOR_REMOTE_HOST_ADDRESS>

# Push local workspace changes to the remote VM.
# `.git` is excluded from the sync, so the current short commit hash is first
# recorded in VERSION, which the IMAGE_TAG default reads when git is unavailable.
# VERSION is only rewritten when git actually yields a hash: redirecting
# `git rev-parse` straight into the file would truncate it to empty on failure.
# This step stays best-effort (`|| true`) so a missing git never blocks the sync.
push-vm:
	@rev=$$(git rev-parse --short HEAD 2>/dev/null) && [ -n "$$rev" ] && printf '%s\n' "$$rev" > VERSION || true
	rsync -avz --exclude '.git' --exclude '.venv' --exclude '__pycache__' --exclude '*.pyc' --exclude '.DS_Store' --exclude 'scratch' ./ $(REMOTE_HOST):~/open-rl

# Pull changes from the remote VM back to the local workspace
Expand Down
32 changes: 32 additions & 0 deletions docs/setup/gke-fft-timeslice.md
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,38 @@ Notes:
plugin's [time-slicing config](https://github.com/NVIDIA/k8s-device-plugin#shared-access-to-gpus-with-cuda-time-slicing)
(`replicas: 2`) plus the node label.

## Setup 1.5: Install and deploy llm-d Snapshot Agent DaemonSet

In cluster deployments `open-rl-snapshot-agent` runs with `--backend llmd`, so it does not freeze or unfreeze CUDA processes itself: it delegates the kernel-level checkpointing (`cuda-checkpoint`) to `llmd-snapshot-agent`, which it reaches over gRPC on `127.0.0.1:9001`.

To build and deploy the official `llmd-snapshot-agent` DaemonSet on GPU nodes:

1. **Clone the official `llm-d-rl-time-slicing` repository:**
```bash
git clone https://github.com/llm-d-incubation/llm-d-rl-time-slicing.git ~/.cache/checkouts/github.com/llm-d-incubation/llm-d-rl-time-slicing
cd ~/.cache/checkouts/github.com/llm-d-incubation/llm-d-rl-time-slicing
```

2. **Build and push the Go Daemon container image:**
*(Note: the Dockerfile expects a `bin/` directory to be present in the build context. Make sure `bin/` exists before running `docker build`; the `mkdir -p` below creates it if it is missing.)*
```bash
mkdir -p bin/
DOCKER_BUILDKIT=1 docker build -t gcr.io/<YOUR_GCP_PROJECT>/llmd-snapshot-agent:latest -f deploy/snapshot-agent/Dockerfile .
docker push gcr.io/<YOUR_GCP_PROJECT>/llmd-snapshot-agent:latest
```

3. **Render the Helm chart and apply it to the `timeslice-system` namespace:**
```bash
helm template snapshot-agent deploy/snapshot-agent/ \
--namespace timeslice-system \
--set image.repository=gcr.io/<YOUR_GCP_PROJECT>/llmd-snapshot-agent \
--set image.tag=latest \
--set tolerations[0].key=nvidia.com/gpu \
--set tolerations[0].operator=Exists \
--set tolerations[0].effect=NoSchedule | kubectl apply -f -
```
Verify both DaemonSet Pods (`llmd-snapshot-agent-*`) transition to `Running` and actively listen on TCP `9001` across all target GPU nodes.

## Setup 2: Build, push, and deploy OpenRL

```bash
Expand Down
13 changes: 8 additions & 5 deletions examples/sft/gsm8k/gsm8k_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def make_datum(row: dict) -> tinker.Datum:
loss_fn_inputs=cast(Any, {"target_tokens": tokens[1:], "weights": [float(w) for w in weights[1:]]}),
)

return SupervisedDatasetFromHFDataset(dataset, self.batch_size, map_fn=make_datum), None
eval_dataset = load_dataset("openai/gsm8k", "main", split="test[:16]")
return (
SupervisedDatasetFromHFDataset(dataset, self.batch_size, map_fn=make_datum),
SupervisedDatasetFromHFDataset(eval_dataset, self.batch_size, map_fn=make_datum),
)


@chz.chz
Expand All @@ -52,8 +56,9 @@ class Config:
rank: int = 32
max_len: int = 640
seed: int = 0
max_steps: int | None = None
max_steps: int | None = 10
save_every: int = 0
eval_every: int = 5
behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "delete"


Expand All @@ -76,7 +81,7 @@ def main(config: Config) -> None:
lora_rank=config.rank,
base_url=config.base_url,
save_every=config.save_every,
eval_every=0,
eval_every=config.eval_every,
infrequent_eval_every=0,
max_steps=config.max_steps,
)
Expand All @@ -87,8 +92,6 @@ def main(config: Config) -> None:
checkpoint = checkpoint_utils.get_last_checkpoint(config.log_path, required_key="state_path")
if checkpoint is not None:
path = checkpoint.sampler_path or checkpoint.state_path
if path and path.startswith("tinker://"):
path = str(Path(os.getenv("OPEN_RL_TMP_DIR", "/tmp/open-rl")) / "sampler_full" / path.removeprefix("tinker://"))
if path:
print(f"eval_model_path={path}")

Expand Down
65 changes: 46 additions & 19 deletions examples/sft/gsm8k/vllm_eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import json
import os
import re
import time

Expand All @@ -17,37 +18,63 @@ def extract(text: str) -> str | None:


def main() -> None:
from vllm import LLM, SamplingParams
from tinker import ServiceClient, types
from tinker_cookbook.tokenizer_utils import get_tokenizer

parser = argparse.ArgumentParser()
parser.add_argument("--path", required=True)
parser.add_argument("--path", required=True, action="append", help="One or more URI paths to evaluate concurrently")
parser.add_argument("--base-model", default="Qwen/Qwen2.5-0.5B")
parser.add_argument("--base-url", default=os.getenv("TINKER_BASE_URL", os.getenv("BASE_URL", "http://127.0.0.1:8000")))
parser.add_argument("--data", default="gsm8k_test.json")
parser.add_argument("--gpu-memory-utilization", type=float, default=0.85)
parser.add_argument("--microbatch-size", type=int, default=10, help="Number of evaluation problems to dispatch per micro-batch")
parser.add_argument("--min-accuracy", type=float, default=0.0, help="exit nonzero if accuracy falls below this fraction")
args = parser.parse_args()

with open(args.data) as f:
data = json.load(f)

llm = LLM(
model=args.path,
dtype="bfloat16",
gpu_memory_utilization=args.gpu_memory_utilization,
max_model_len=1024,
enforce_eager=True,
)
sampling_params = SamplingParams(temperature=0.0, max_tokens=256, stop=["\nQuestion:"])
paths = args.path if isinstance(args.path, list) else [args.path]
client = ServiceClient(api_key=os.getenv("TINKER_API_KEY", "tml-dummy-key"), base_url=args.base_url)
samplers = [client.create_sampling_client(p) for p in paths]
tokenizer = get_tokenizer(args.base_model)

sampling_params = types.SamplingParams(temperature=0.0, max_tokens=256)
start = time.time()
outputs = llm.generate([datum["prompt"] for datum in data], sampling_params)

import asyncio

async def run_evals():
outputs_by_sampler = [[] for _ in paths]
batch_size = args.microbatch_size
for i in range(0, len(data), batch_size):
chunk = data[i : i + batch_size]
for s_idx, sampler in enumerate(samplers):
tasks = [
sampler.sample_async(
prompt=types.ModelInput.from_ints(tokens=tokenizer.encode(datum["prompt"], add_special_tokens=False)),
num_samples=1,
sampling_params=sampling_params,
)
for datum in chunk
]
res_list = await asyncio.gather(*tasks)
for res in res_list:
seqs = res.sequences
outputs_by_sampler[s_idx].append(tokenizer.decode(seqs[0].tokens) if seqs else "")
return outputs_by_sampler

outputs_by_sampler = asyncio.run(run_evals())

elapsed = time.time() - start
correct = sum(int(extract(output.outputs[0].text) == datum["gold"]) for datum, output in zip(data, outputs, strict=True))
accuracy = correct / len(data)

print("***************************************************************")
print(f"[VLLM] {args.path} 0-shot GSM8K acc = {accuracy:.1%} on {len(data)} problems in {elapsed:.1f}s")
print("***************************************************************")
if accuracy < args.min_accuracy:
raise SystemExit(f"GSM8K accuracy {accuracy:.1%} is below the required {args.min_accuracy:.1%}")
for path, outputs in zip(paths, outputs_by_sampler, strict=True):
correct = sum(int(extract(text) == datum["gold"]) for datum, text in zip(data, outputs, strict=True))
accuracy = correct / len(data)
print("***************************************************************")
print(f"[SAMPLER] {path} 0-shot GSM8K acc = {accuracy:.1%} on {len(data)} problems in {elapsed:.1f}s")
print("***************************************************************")
if accuracy < args.min_accuracy:
raise SystemExit(f"GSM8K accuracy {accuracy:.1%} for {path} is below the required {args.min_accuracy:.1%}")


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion k8s/deploy/distributed-fft-timeslice/04-gateway.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ spec:
containers:
- name: gateway
image: ghcr.io/gke-labs/open-rl/gateway:latest
imagePullPolicy: Always
imagePullPolicy: IfNotPresent
command: ["uv", "run", "uvicorn", "server.gateway:app", "--host", "0.0.0.0", "--port", "8000"]
ports:
- containerPort: 8000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ data:
containers:
- name: trainer-worker
image: gcr.io/cdrollouts-sunilarora/open-rl-server:latest
imagePullPolicy: Always
imagePullPolicy: IfNotPresent
command: ["uv", "run", "python", "-m", "server.training_requests_processor"]
env:
- name: REDIS_URL
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ spec:
containers:
- name: snapshot-agent
image: ghcr.io/gke-labs/open-rl/server:latest
imagePullPolicy: Always
imagePullPolicy: IfNotPresent
command: ["uv", "run", "python", "-m", "snapshot_agent.serve"]
args: ["--listen-host", "0.0.0.0", "--port", "9753", "--backend", "llmd", "--llmd-snapshot-endpoint", "127.0.0.1:9001"]
ports:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ data:
containers:
- name: sampler-worker
image: gcr.io/cdrollouts-sunilarora/open-rl-server:latest
imagePullPolicy: Always
imagePullPolicy: IfNotPresent
command: ["uv", "run", "python", "-u", "-m", "server.vllm_sampler"]
env:
- name: REDIS_URL
Expand Down
Loading
Loading