|
| 1 | +"""Train a cyber agent on an OpenRange world pool with rLLM's ``AgentTrainer``. |
| 2 | +
|
| 3 | +This is the rLLM half of "one scaffold, two modes": the *same* agent loop that |
| 4 | +``examples/codex_eval.py`` evaluates with is trained here, swapping only the |
| 5 | +sampler. ``openrange_rllm`` maps each OpenRange episode onto rLLM's |
| 6 | +``Episode``/``Step`` and exposes the policy as an ``@rllm.rollout`` flow; rLLM's |
| 7 | +gateway captures token ids and logprobs, GRPO does the rest. The reward is the |
| 8 | +pack's own dense subgoal ladder (no reward logic here). |
| 9 | +
|
| 10 | +A pool of command-injection "company" worlds becomes an rLLM dataset (one row per |
| 11 | +pentest task, carrying its ``snapshot_id``/``task_id``); ``snapshot_resolver`` |
| 12 | +maps each sampled rLLM task back to its world. The agent reaches the live webapp |
| 13 | +over HTTP from a host shell (PROCESS backing) and composes ``curl`` itself. |
| 14 | +
|
| 15 | +Run on one CUDA GPU through rLLM's verl backend. Validated end to end on an |
| 16 | +A100-40GB inside the maintainers' ``verlai/verl:vllm011.latest`` image (torch 2.8 |
| 17 | +/ vLLM 0.11 / flash-attn):: |
| 18 | +
|
| 19 | + python -m examples.rllm_grpo_cyber \ |
| 20 | + rllm/backend=verl algorithm.adv_estimator=grpo \ |
| 21 | + +model.name=Qwen/Qwen2.5-7B-Instruct \ |
| 22 | + actor_rollout_ref.model.path=Qwen/Qwen2.5-7B-Instruct \ |
| 23 | + actor_rollout_ref.model.lora_rank=32 \ |
| 24 | + actor_rollout_ref.model.lora_alpha=32 \ |
| 25 | + actor_rollout_ref.actor.use_dynamic_bsz=True \ |
| 26 | + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=16384 \ |
| 27 | + actor_rollout_ref.actor.use_kl_loss=False \ |
| 28 | + actor_rollout_ref.rollout.name=vllm \ |
| 29 | + actor_rollout_ref.rollout.mode=async \ |
| 30 | + actor_rollout_ref.rollout.enforce_eager=True \ |
| 31 | + actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ |
| 32 | + actor_rollout_ref.rollout.n=4 \ |
| 33 | + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \ |
| 34 | + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \ |
| 35 | + trainer.n_gpus_per_node=1 data.train_batch_size=2 \ |
| 36 | + rllm.trainer.total_batches=1 |
| 37 | +
|
| 38 | +Gotchas (both cost real debugging): |
| 39 | +
|
| 40 | +- LoRA uses the **flat** keys ``lora_rank`` / ``lora_alpha``. The nested |
| 41 | + ``lora.rank`` is silently ignored, which means full fine-tuning — a 7B then |
| 42 | + OOMs a 40GB card, whereas with LoRA on it fits comfortably. |
| 43 | +- OpenRange currently requires Python **3.14** (PEP 758 ``except`` syntax) but the |
| 44 | + verl GPU stack ships only Python **3.12** wheels, so they cannot share a |
| 45 | + process. Running this needs OpenRange made 3.12-compatible (a 3-line change) or |
| 46 | + 3.14 wheels for torch/vLLM/flash-attn (not available yet). |
| 47 | +
|
| 48 | +Building the pool + registering the dataset is CPU-only; ``trainer.train()`` is |
| 49 | +the CUDA boundary (Ray + vLLM + FSDP under the verl backend). |
| 50 | +""" |
| 51 | + |
| 52 | +from __future__ import annotations |
| 53 | + |
| 54 | +import subprocess |
| 55 | +from collections.abc import Mapping |
| 56 | +from typing import Any |
| 57 | + |
| 58 | +from cyber_webapp import WebappPack |
| 59 | +from cyber_webapp.difficulty import world_difficulty |
| 60 | +from openrange_pack_sdk import Snapshot |
| 61 | +from openrange_rllm import ( |
| 62 | + build_rllm_dataset_rows, |
| 63 | + make_evaluator, |
| 64 | + make_rollout, |
| 65 | + snapshot_resolver, |
| 66 | +) |
| 67 | + |
| 68 | +from openrange import EpisodeService, EvalPool, WorldPool |
| 69 | +from openrange.core.sandbox import CommandResult |
| 70 | + |
# Name under which the train/test rows are registered in, and loaded back from,
# rLLM's ``DatasetRegistry``.
_DATASET = "openrange_cyber"
# Task family passed to both world pools and to the dataset-row builder.
_FAMILY = "webapp.pentest"
# Second argument handed to ``EpisodeService``; presumably the directory that
# episode runs are written under (confirm against ``EpisodeService``).
_RUN_ROOT = "or-runs/cyber-rllm"
| 74 | + |
| 75 | + |
def _company(seed: int) -> dict[str, Any]:
    """Build the world spec for one "company" webapp world.

    Every spec is identical except for ``seed``: a chain topology on the
    ``webapp`` pack, ticking disabled, and no NPCs.

    Args:
        seed: Value stored under the spec's ``"seed"`` key.

    Returns:
        A freshly built spec dictionary (a new object on every call).
    """
    world = {"goal": "recover the hidden flag from the internal estate"}
    pack = {"id": "webapp"}
    runtime = {"tick": {"mode": "off"}}
    spec: dict[str, Any] = {
        "world": world,
        "pack": pack,
        "runtime": runtime,
        "npc": [],
        "seed": seed,
        "topology": "chain",
    }
    return spec
| 85 | + |
| 86 | + |
def _difficulty(snapshot: Snapshot) -> float:
    """Return the difficulty of ``snapshot`` as a plain ``float``.

    The score comes from ``world_difficulty`` applied to the snapshot's
    ``graph``; the result is coerced with ``float`` before being returned.
    """
    score = world_difficulty(snapshot.graph)
    return float(score)
| 89 | + |
| 90 | + |
class _HostRun:
    """Command runner that executes each command in a host ``bash`` login shell."""

    def run(self, command: str, *, timeout: float = 120.0) -> CommandResult:
        """Run ``command`` via ``bash -lc`` and return its exit status and output.

        Args:
            command: Shell command line handed to ``bash -lc`` as one string.
            timeout: Seconds to wait before ``subprocess.run`` gives up.

        Returns:
            A ``CommandResult`` built from the process return code and the
            captured text: all of stdout followed by all of stderr (the two
            streams are concatenated, not interleaved).

        Raises:
            subprocess.TimeoutExpired: If the command outlives ``timeout``;
                it is not caught here. A non-zero exit status does not raise
                (``check=False``).
        """
        argv = ["bash", "-lc", command]
        completed = subprocess.run(
            argv,
            check=False,
            timeout=timeout,
            text=True,
            capture_output=True,
        )
        combined_output = completed.stdout + completed.stderr
        return CommandResult(completed.returncode, combined_output)

    def close(self) -> None:
        """Do nothing; ``run`` starts a separate subprocess for every call."""
| 104 | + |
| 105 | + |
def _host_bind(_surface: Mapping[str, Any]) -> _HostRun:
    """Return a new host-shell runner; the ``_surface`` mapping is ignored."""
    runner = _HostRun()
    return runner
| 108 | + |
| 109 | + |
def main() -> None:
    """Build the world pools, register the datasets, and launch rLLM training.

    The Hydra-decorated inner function receives the composed ``unified``
    config from ``pkg://rllm.trainer.config`` and does all the work; ``main``
    only defines it and invokes it.
    """
    # Imported inside the function rather than at module level — presumably so
    # importing this module does not require the training stack (confirm).
    import hydra
    from omegaconf import DictConfig
    from rllm.data.dataset import DatasetRegistry
    from rllm.trainer import AgentTrainer

    @hydra.main(  # type: ignore[untyped-decorator]
        config_path="pkg://rllm.trainer.config",
        config_name="unified",
        version_base=None,
    )
    def _train(config: DictConfig) -> None:
        """Seed the pools, register train/test splits, and run the trainer."""
        pack = WebappPack()

        # Training worlds come from seeds 0-3; the pool is capped at 8 entries.
        train_specs = [_company(seed) for seed in range(4)]
        train_pool = WorldPool.seed(
            pack,
            train_specs,
            difficulty_fn=_difficulty,
            family=_FAMILY,
            max_size=8,
        )

        # Held-out evaluation worlds use seeds disjoint from the training ones.
        val_specs = [_company(seed) for seed in (7, 8)]
        val_pool = EvalPool.seed(
            pack,
            val_specs,
            difficulty_fn=_difficulty,
            family=_FAMILY,
        )

        # Register the train split first, then the test split, under one name.
        for split, pool in (("train", train_pool), ("test", val_pool)):
            rows = build_rllm_dataset_rows(pool.snapshots(), family=_FAMILY)
            DatasetRegistry.register_dataset(_DATASET, rows, split)

        # The resolver must know every snapshot a sampled task may refer to,
        # so it is given both the training and the evaluation snapshots.
        known_snapshots = list(train_pool.snapshots())
        known_snapshots.extend(val_pool.snapshots())
        resolve = snapshot_resolver(known_snapshots)

        service = EpisodeService(pack, _RUN_ROOT)

        backend = config.rllm.get("backend", "verl")
        rollout = make_rollout(service, resolve, bind_run=_host_bind)
        evaluator = make_evaluator()
        train_dataset = DatasetRegistry.load_dataset(_DATASET, "train")
        val_dataset = DatasetRegistry.load_dataset(_DATASET, "test")

        trainer = AgentTrainer(
            backend=backend,
            agent_flow=rollout,
            evaluator=evaluator,
            config=config,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
        )
        trainer.train()

    _train()
| 159 | + |
| 160 | + |
# Script entry point: start training only when this module is executed
# directly (e.g. ``python -m examples.rllm_grpo_cyber``), not when imported.
if __name__ == "__main__":  # pragma: no cover
    main()
0 commit comments