Skip to content

Commit b3cb5e1

Browse files
authored
Minor improvements to make RL workflow more robust (#319)
1 parent 3761c0f commit b3cb5e1

File tree

11 files changed

+147
-33
lines changed

11 files changed

+147
-33
lines changed

.github/workflows/examples-calc-x.yml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,14 @@ jobs:
239239
WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}
240240
id: calc_x_train_local_model
241241

242+
- name: Validate training with local model
243+
run: |
244+
set -ex
245+
uv run scripts/validate_example_wandb.py ${{ steps.calc_x_train_local_model.outputs.project_name }} ${{ steps.calc_x_train_local_model.outputs.run_name }}
246+
env:
247+
WANDB_BASE_URL: ${{ secrets.MSR_WANDB_BASE_URL }}
248+
WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}
249+
242250
- name: Training with LLM Proxy
243251
run: |
244252
set -ex
@@ -254,6 +262,14 @@ jobs:
254262
WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}
255263
id: calc_x_train_llm_proxy
256264

265+
- name: Validate training with LLM Proxy
266+
run: |
267+
set -ex
268+
uv run scripts/validate_example_wandb.py ${{ steps.calc_x_train_llm_proxy.outputs.project_name }} ${{ steps.calc_x_train_llm_proxy.outputs.run_name }}
269+
env:
270+
WANDB_BASE_URL: ${{ secrets.MSR_WANDB_BASE_URL }}
271+
WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}
272+
257273
- name: Training with external store
258274
run: |
259275
set -euo pipefail
@@ -284,6 +300,14 @@ jobs:
284300
WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}
285301
id: calc_x_train_external_store
286302

303+
- name: Validate training with external store
304+
run: |
305+
set -ex
306+
uv run scripts/validate_example_wandb.py ${{ steps.calc_x_train_external_store.outputs.project_name }} ${{ steps.calc_x_train_external_store.outputs.run_name }}
307+
env:
308+
WANDB_BASE_URL: ${{ secrets.MSR_WANDB_BASE_URL }}
309+
WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}
310+
287311
- name: Training with role-based environment variables
288312
run: |
289313
set -euo pipefail
@@ -305,3 +329,12 @@ jobs:
305329
env:
306330
WANDB_BASE_URL: ${{ secrets.MSR_WANDB_BASE_URL }}
307331
WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}
332+
id: calc_x_train_role_based_env_var
333+
334+
- name: Validate training with role-based environment variables
335+
run: |
336+
set -ex
337+
uv run scripts/validate_example_wandb.py ${{ steps.calc_x_train_role_based_env_var.outputs.project_name }} ${{ steps.calc_x_train_role_based_env_var.outputs.run_name }}
338+
env:
339+
WANDB_BASE_URL: ${{ secrets.MSR_WANDB_BASE_URL }}
340+
WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}

.github/workflows/examples-spider.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ jobs:
121121
- name: Validate Spider training
122122
run: |
123123
set -ex
124-
uv run scripts/validate_example_wandb.py ${{ steps.spider_train.outputs.project_name }} ${{ steps.spider_train.outputs.run_name }}
124+
uv run scripts/validate_example_wandb.py ${{ steps.spider_train.outputs.project_name }} ${{ steps.spider_train.outputs.run_name }} --reward-tolerance 5
125125
env:
126126
WANDB_BASE_URL: ${{ secrets.MSR_WANDB_BASE_URL }}
127127
WANDB_API_KEY: ${{ secrets.MSR_WANDB_API_KEY }}

agentlightning/execution/client_server.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
terminate_timeout: float = 10.0,
7272
main_process: Literal["algorithm", "runner"] = "algorithm",
7373
managed_store: bool | None = None,
74+
allowed_exit_codes: Iterable[int] = (0, -15),
7475
) -> None:
7576
"""Configure the strategy.
7677
@@ -94,6 +95,9 @@ def __init__(
9495
LightningStore client/server wrappers automatically. When
9596
`False` the provided `store` is passed directly to the
9697
bundles, allowing callers to manage store wrappers manually.
98+
allowed_exit_codes: Allowed exit codes for subprocesses.
99+
By default, a runner may exit gracefully with code 0 or be terminated
100+
by SIGTERM (-15).
97101
"""
98102
if role is None:
99103
role_env = os.getenv("AGL_CURRENT_ROLE")
@@ -133,6 +137,7 @@ def __init__(
133137
raise ValueError("main_process='runner' requires n_runners to be 1")
134138
self.main_process = main_process
135139
self.managed_store = resolve_managed_store_flag(managed_store)
140+
self.allowed_exit_codes = tuple(allowed_exit_codes)
136141

137142
async def _execute_algorithm(
138143
self, algorithm: AlgorithmBundle, store: LightningStore, stop_evt: ExecutionEvent
@@ -338,10 +343,10 @@ def _shutdown_processes(
338343

339344
def _check_process_exitcodes(self, processes: Iterable[multiprocessing.Process]) -> None:
340345
"""Raise an error if any finished managed process exited with a code outside `self.allowed_exit_codes`."""
341-
failed = [p for p in processes if p.exitcode not in (0, None)]
346+
failed = [p for p in processes if p.exitcode not in self.allowed_exit_codes + (None,)]
342347
if failed:
343348
formatted = ", ".join(f"{p.name or p.pid} (exitcode={p.exitcode})" for p in failed)
344-
raise RuntimeError(f"Subprocesses failed: {formatted}")
349+
raise RuntimeError(f"Subprocesses failed with unexpected exit codes: {formatted}")
345350

346351
def execute(self, algorithm: AlgorithmBundle, runner: RunnerBundle, store: LightningStore) -> None:
347352
logger.info(

agentlightning/runner/agent.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import asyncio
1313
import logging
14+
import random
1415
import threading
1516
import time
1617
from contextlib import suppress
@@ -72,6 +73,7 @@ def __init__(
7273
max_rollouts: Optional[int] = None,
7374
poll_interval: float = 5.0,
7475
heartbeat_interval: float = 10.0,
76+
interval_jitter: float = 0.1,
7577
heartbeat_launch_mode: Literal["asyncio", "thread"] = "asyncio",
7678
) -> None:
7779
"""Initialize the agent runner.
@@ -82,6 +84,9 @@ def __init__(
8284
[`iter`][agentlightning.LitAgentRunner.iter].
8385
poll_interval: Seconds to wait between store polls when no work is available.
8486
heartbeat_interval: Seconds to wait between sending heartbeats to the store.
87+
interval_jitter: Maximum jitter, in seconds, applied to both the poll interval and the heartbeat interval. Each actual interval is drawn uniformly between
88+
interval - interval_jitter and interval + interval_jitter, and is never shorter than 0.01 seconds.
89+
This is to avoid the overload caused by the synchronization of the runners.
8590
heartbeat_launch_mode: Launch mode for the heartbeat loop. Can be "asyncio" or "thread".
8691
"asyncio" is the default and recommended mode. Use "thread" if you are experiencing blocking coroutines.
8792
"""
@@ -90,7 +95,9 @@ def __init__(
9095
self._max_rollouts = max_rollouts
9196
self._poll_interval = poll_interval
9297
self._heartbeat_interval = heartbeat_interval
98+
self._interval_jitter = interval_jitter
9399
self._heartbeat_launch_mode = heartbeat_launch_mode
100+
self._random_state = random.Random()
94101

95102
# Set later
96103
self._agent: Optional[LitAgent[T_task]] = None
@@ -360,7 +367,11 @@ async def heartbeat_loop() -> None:
360367
while not stop_event.is_set():
361368
await self._emit_heartbeat(store)
362369
with suppress(asyncio.TimeoutError):
363-
await asyncio.wait_for(stop_event.wait(), timeout=self._heartbeat_interval)
370+
interval = self._heartbeat_interval + self._random_state.uniform(
371+
-self._interval_jitter, self._interval_jitter
372+
)
373+
interval = max(interval, 0.01)
374+
await asyncio.wait_for(stop_event.wait(), timeout=interval)
364375

365376
task = asyncio.create_task(heartbeat_loop(), name=f"{self.get_worker_id()}-heartbeat")
366377

@@ -379,7 +390,11 @@ def thread_worker() -> None:
379390
asyncio.set_event_loop(loop)
380391
while not stop_evt.is_set():
381392
loop.run_until_complete(self._emit_heartbeat(store))
382-
stop_evt.wait(self._heartbeat_interval)
393+
interval = self._heartbeat_interval + self._random_state.uniform(
394+
-self._interval_jitter, self._interval_jitter
395+
)
396+
interval = max(interval, 0.01)
397+
stop_evt.wait(interval)
383398

384399
thread = threading.Thread(target=thread_worker, name=f"{self.get_worker_id()}-heartbeat", daemon=True)
385400
thread.start()
@@ -402,11 +417,13 @@ async def _sleep_until_next_poll(self, event: Optional[ExecutionEvent] = None) -
402417
event: Optional [`ExecutionEvent`][agentlightning.ExecutionEvent] object that can be used to interrupt the sleep.
403418
If set during the sleep period, the method returns immediately.
404419
"""
420+
interval = self._poll_interval + self._random_state.uniform(-self._interval_jitter, self._interval_jitter)
421+
interval = max(interval, 0.01)
405422
if event is None:
406-
await asyncio.sleep(self._poll_interval)
423+
await asyncio.sleep(interval)
407424
return
408425
current_time = time.time()
409-
next_time = current_time + self._poll_interval
426+
next_time = current_time + interval
410427
while time.time() < next_time:
411428
await asyncio.sleep(0.1)
412429
if event.is_set():

agentlightning/store/client_server.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1131,6 +1131,7 @@ def __getstate__(self):
11311131
are excluded as they should not be transferred between processes.
11321132
"""
11331133
return {
1134+
"server_address_root": self.server_address_root,
11341135
"server_address": self.server_address,
11351136
"_retry_delays": self._retry_delays,
11361137
"_health_retry_delays": self._health_retry_delays,
@@ -1145,6 +1146,7 @@ def __setstate__(self, state: Dict[str, Any]):
11451146
Replicating `__init__` logic to create another client instance in the subprocess.
11461147
"""
11471148
self.server_address = state["server_address"]
1149+
self.server_address_root = state["server_address_root"]
11481150
self._sessions = {}
11491151
self._lock = threading.Lock()
11501152
self._retry_delays = state["_retry_delays"]

agentlightning/utils/server_launcher.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from contextlib import asynccontextmanager, suppress
1616
from dataclasses import dataclass
1717
from multiprocessing.process import BaseProcess
18-
from typing import Any, AsyncContextManager, AsyncIterator, Dict, Literal, Optional
18+
from typing import Any, AsyncContextManager, AsyncIterator, Dict, Literal, Optional, cast
1919

2020
import aiohttp
2121
import requests
@@ -65,6 +65,8 @@ class PythonServerLauncherArgs:
6565
"""The timeout to wait for the thread to join."""
6666
process_join_timeout: float = 10.0
6767
"""The timeout to wait for the process to join."""
68+
timeout_keep_alive: int = 30
69+
"""The keep-alive timeout, in seconds, forwarded to uvicorn's `timeout_keep_alive` setting."""
6870

6971

7072
@dataclass
@@ -650,7 +652,7 @@ def __getstate__(self):
650652

651653
def __setstate__(self, state: Dict[str, Any]):
652654
self.app = state["app"]
653-
self.args = state["args"]
655+
self.args = cast(PythonServerLauncherArgs, state["args"])
654656
self.serve_context = state["serve_context"]
655657
self._host = state["_host"]
656658
self._port = state["_port"]
@@ -796,6 +798,7 @@ def _create_uvicorn_server(self) -> uvicorn.Server:
796798
log_level=self.args.log_level,
797799
access_log=self.args.access_log,
798800
loop="asyncio",
801+
timeout_keep_alive=self.args.timeout_keep_alive,
799802
)
800803
return uvicorn.Server(config)
801804

agentlightning/verl/daemon.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,12 @@
1818
from tensordict import TensorDict
1919
from verl import DataProto
2020

21-
from agentlightning import LLM, AgentLightningServer, NamedResources, RolloutLegacy, setup_logging
21+
from agentlightning import LLM, AgentLightningServer, NamedResources, RolloutLegacy
2222
from agentlightning.adapter.triplet import TracerTraceToTriplet, TraceToTripletBase
2323
from agentlightning.llm_proxy import LLMProxy, ModelConfig
2424
from agentlightning.store.base import LightningStore
2525
from agentlightning.types import Rollout, RolloutConfig, Task
2626

27-
setup_logging()
28-
2927
__all__ = [
3028
"AgentModeDaemon",
3129
"get_left_padded_ids_and_attention_mask",

examples/calc_x/calc_agent.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,19 @@ async def calc_agent(task: MathProblem, llm: agl.LLM) -> None:
8484
try:
8585
output_format = "Output the answer when you are ready. The answer should be surrounded by three sharps (`###`), in the form of ### ANSWER: <answer> ###."
8686
prompt = task["question"] + " " + output_format
87-
result = await calc_agent.run(task=prompt)
87+
# Sometimes MCP tools can time out. In that case, the whole agent will block.
88+
# We thus set a timeout of 5 minutes so that the agent will not block indefinitely.
89+
result = await asyncio.wait_for(calc_agent.run(task=prompt), timeout=300.0)
8890
# evaluate
8991
last_message = cast(str, result.messages[-1].content) # type: ignore
9092
answer = re.search(r"###\s*ANSWER:\s*(.+?)(\s*###|$)", last_message)
9193
if answer:
9294
answer = answer.group(1)
9395
else:
9496
answer = last_message
97+
except asyncio.TimeoutError as e:
98+
print("Timeout occurred. Error:", str(e))
99+
answer = "None"
95100
except Exception as e:
96101
print("Failure:", str(e))
97102
answer = "None"

examples/calc_x/train_calc_agent.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import argparse
3232
import os
33+
import uuid
3334
from datetime import datetime
3435
from typing import Any, Dict, Optional, cast
3536

@@ -146,20 +147,25 @@ def train(
146147
if ci or ci_fast:
147148
# Config the experiment name and project name so that they are available to CI
148149
timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
149-
EXPERIMENT_NAME = f"calc_x_{timestamp}"
150+
random_suffix = uuid.uuid4().hex[:8]
151+
EXPERIMENT_NAME = f"calc_x_{timestamp}_{random_suffix}"
150152

151153
PROJECT_NAME = "AgentLightningCI"
152154

153-
# Simulate writing to $GITHUB_OUTPUT if it’s set
154-
github_output = os.getenv("GITHUB_OUTPUT")
155-
if github_output:
156-
with open(github_output, "a") as f:
157-
f.write(f"project_name={PROJECT_NAME}\n")
158-
f.write(f"run_name={EXPERIMENT_NAME}\n")
155+
# Skip this step if AGL_CURRENT_ROLE is runner
156+
agl_current_role = os.getenv("AGL_CURRENT_ROLE")
159157

160-
print("Set environment variables:")
161-
print(f"PROJECT_NAME={PROJECT_NAME}")
162-
print(f"EXPERIMENT_NAME={EXPERIMENT_NAME}")
158+
if agl_current_role != "runner":
159+
# Simulate writing to $GITHUB_OUTPUT if it’s set
160+
github_output = os.getenv("GITHUB_OUTPUT")
161+
if github_output:
162+
with open(github_output, "a") as f:
163+
f.write(f"project_name={PROJECT_NAME}\n")
164+
f.write(f"run_name={EXPERIMENT_NAME}\n")
165+
166+
print("Set environment variables:")
167+
print(f"PROJECT_NAME={PROJECT_NAME}")
168+
print(f"EXPERIMENT_NAME={EXPERIMENT_NAME}")
163169

164170
# Keep it tiny/light without adding new knobs
165171
config["actor_rollout_ref"]["rollout"]["gpu_memory_utilization"] = 0.8
@@ -210,6 +216,7 @@ def main():
210216
default="",
211217
help="Connect to an external store instead of creating a new one in memory",
212218
)
219+
parser.add_argument("--debug", action="store_true", help="Enable debug logging")
213220

214221
args = parser.parse_args()
215222

@@ -224,6 +231,8 @@ def main():
224231
if args.ci_fast:
225232
args.ci = True
226233

234+
agl.setup_logging("DEBUG" if args.debug else "INFO")
235+
227236
train(
228237
train_file=args.train_file,
229238
val_file=args.val_file,

0 commit comments

Comments
 (0)