Skip to content

Commit 8f4bbaf

Browse files
authored
[example] Add fully async example (#258)
* [example] Add fully async example * add README
1 parent 0445a84 commit 8f4bbaf

File tree

7 files changed

+428
-4
lines changed

7 files changed

+428
-4
lines changed

examples/fully_async/README.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
## Fully Asynchronous Rollout Example
2+
3+
This example shows a simple way to make rollout generation **fully asynchronous**: a single global worker is created once and then keeps running in the background, continuously pulling prompts and launching generation tasks. Training only needs to fetch already finished results. This removes the per‑step wait that happens in the normal synchronous style.
4+
5+
### Files
6+
* `fully_async_rollout.py`: global async worker + `generate_rollout_fully_async` entry.
7+
* `run-qwen3-4b-fully_async.sh`: example launch script with Qwen3‑4B.
8+
9+
### Prerequisite
10+
First set up model & environment following [Example: Qwen3-4B Model](../../docs/en/models/qwen3-4B.md).
11+
12+
### Quick Start
13+
```bash
14+
cd slime
15+
bash examples/fully_async/run-qwen3-4b-fully_async.sh
16+
```
17+
You should see log lines like:
18+
```
19+
Creating new global async worker...
20+
Continuous async rollout worker started
21+
```
22+
23+
### How It Works (Very Short)
24+
* First call: create `AsyncRolloutWorker` (thread + asyncio loop).
25+
* Loop keeps up to `--rollout-batch-size` tasks in flight using `generate_and_rm_group`.
26+
* Completed groups are pushed into a queue; caller drains until it has enough samples.
27+
* Worker is stopped automatically at process exit.
28+
29+
### Limitations
30+
* No evaluation mode.
31+
* Ordering is best-effort (groups are sorted at the end by sample index).
32+
* Minimal error handling.
33+
34+
### Config Differences (2 Key Points)
35+
To enable the fully async pattern there are only two changes compared to a normal run:
36+
37+
1. Use the async training driver: `train_async.py` (not `train.py`).
38+
2. Set the rollout function path:
39+
```bash
40+
--rollout-function-path fully_async_rollout.generate_rollout_fully_async
41+
```
42+
43+
Why is it still "fully" async although `train_async.py` itself schedules rollouts step‑by‑step?
44+
45+
Because the real generation work is done by a **persistent background worker** created in `generate_rollout_fully_async`. Each call from `train_async.py` only drains already completed samples from the worker's output queue; the worker has been continuously generating since the first call. Thus rollout production (model inference) and training consumption happen in parallel with minimal waiting.
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
import asyncio
2+
import threading
3+
import queue
4+
import time
5+
from typing import List
6+
7+
from slime.utils.async_utils import run
8+
from slime.utils.types import Sample
9+
10+
# Import core functions from sglang_rollout directly to avoid code duplication
11+
from slime.rollout.sglang_rollout import generate_and_rm_group, GenerateState
12+
13+
# Global worker manager: a single process-wide AsyncRolloutWorker instance.
# It is created lazily by get_global_worker() and torn down by
# stop_global_worker(); _worker_lock guards both so concurrent callers can
# never create two workers or stop one mid-creation.
_global_worker = None
_worker_lock = threading.Lock()
16+
17+
18+
def get_global_worker(args, data_buffer):
    """Return the process-wide rollout worker, creating or reviving it on demand.

    A new worker is built when none exists yet, or when the previous one's
    background thread has died, so callers always receive a live worker.
    Thread-safe via the module-level ``_worker_lock``.
    """
    global _global_worker
    with _worker_lock:
        needs_new = _global_worker is None or not _global_worker.worker_thread.is_alive()
        if needs_new:
            print("Creating new global async worker...")
            worker = AsyncRolloutWorker(args, data_buffer, concurrency=args.sglang_server_concurrency)
            worker.start()
            _global_worker = worker
        return _global_worker
27+
28+
29+
def stop_global_worker():
    """Shut down the process-wide rollout worker, if one is running.

    Safe to call repeatedly; after a successful stop the global slot is
    cleared so a later ``get_global_worker()`` builds a fresh worker.
    """
    global _global_worker
    with _worker_lock:
        worker = _global_worker
        if worker is not None:
            worker.stop()
            _global_worker = None
36+
37+
38+
class AsyncRolloutWorker:
    """Background rollout generator that runs independently of training steps.

    Simplified asynchronous rollout worker, using threads instead of processes.
    A dedicated thread owns a private asyncio event loop that continuously
    pulls prompt groups from ``data_buffer``, launches generation tasks via
    ``generate_and_rm_group``, and pushes finished groups onto
    ``output_queue``. Supports continuous running, independent of the rollout
    function lifecycle.
    """

    def __init__(self, args, data_buffer, concurrency=10):
        # Rollout configuration; the loop below reads args.rollout_batch_size.
        self.args = args
        self.data_buffer = data_buffer  # Directly save data_buffer reference
        # NOTE(review): `concurrency` is stored but never read in this class —
        # in-flight task count is bounded by args.rollout_batch_size instead.
        # TODO confirm whether this is intentional.
        self.concurrency = concurrency
        # Polled by the event loop each pass; cleared by stop().
        self.running = True
        # Continuous output queue of (group_id, finished_group) tuples,
        # drained from the training side by get_completed_groups().
        self.output_queue = queue.Queue(maxsize=1000)
        self.worker_thread = None  # created lazily by start()
        # Provides the sampling parameters copied into each generation task.
        self.state = GenerateState(args)

    async def continuous_worker_loop(self):
        """Continuous work loop - constantly get data from data_buffer and process.

        Keeps up to ``args.rollout_batch_size`` generation tasks in flight.
        Finished groups reach ``output_queue`` through per-task done callbacks;
        this loop only reaps finished Task objects and surfaces their errors.
        """
        print("Continuous async rollout worker started")

        active_tasks = set()
        max_concurrent_tasks = self.args.rollout_batch_size
        group_id_counter = 0  # monotonically increasing id attached to each launched group

        while self.running:
            try:
                # Clean up completed tasks
                if active_tasks:
                    done_tasks = {task for task in active_tasks if task.done()}
                    for task in done_tasks:
                        try:
                            task.result()  # Results are already handled in callbacks
                        except Exception as e:
                            # NOTE(review): a failed group is only logged here and is
                            # never re-queued or delivered — the consumer will simply
                            # wait for a different group. TODO confirm acceptable.
                            print(f"Task failed with exception: {e}")
                    active_tasks -= done_tasks

                # If active task count hasn't reached limit, try to get new data and start tasks
                while len(active_tasks) < max_concurrent_tasks and self.running:
                    # Assumes get_samples(1) returns promptly (one group per call)
                    # rather than blocking indefinitely — TODO confirm against
                    # the data_buffer implementation.
                    samples = self.data_buffer.get_samples(1)

                    for group in samples:
                        group_id = group_id_counter
                        group_id_counter += 1

                        # Create new async task
                        task = asyncio.create_task(
                            generate_and_rm_group(
                                self.args,
                                group,
                                sampling_params=self.state.sampling_params.copy(),
                                evaluation=False,
                            )
                        )

                        # Add completion callback. make_callback binds the group id
                        # eagerly so the closure does not capture the loop variable.
                        def make_callback(gid):
                            def task_done_callback(task):
                                # NOTE(review): result() re-raises if the task failed
                                # (the exception is then swallowed by asyncio's
                                # callback handler), and put() is a *blocking* call —
                                # if the queue (maxsize=1000) ever fills it would
                                # stall the event loop. Presumably consumers drain
                                # fast enough — TODO confirm.
                                result = task.result()
                                self.output_queue.put((gid, result))

                            return task_done_callback

                        task.add_done_callback(make_callback(group_id))
                        active_tasks.add(task)
                        # samples holds a single group (get_samples(1)), so exit
                        # the for after launching it.
                        break

                # Brief sleep to avoid busy waiting
                await asyncio.sleep(1)

            except Exception as e:
                print(f"Error in continuous worker loop: {e}")
                await asyncio.sleep(1)

        # Graceful shutdown: let in-flight generations finish before exiting
        # the loop's thread.
        if active_tasks:
            print(f"Waiting for {len(active_tasks)} continuous tasks to complete...")
            await asyncio.wait(active_tasks)

        print("Continuous async rollout worker stopped")

    def worker_thread_func(self):
        """Worker function running in independent thread.

        Owns a private event loop for the whole lifetime of the worker via
        asyncio.run().
        """
        asyncio.run(self.continuous_worker_loop())

    def start(self):
        """Start continuous work mode (idempotent while the thread is alive)."""
        if self.worker_thread is None or not self.worker_thread.is_alive():
            # Daemon thread: will not keep the interpreter alive on exit.
            self.worker_thread = threading.Thread(target=self.worker_thread_func, daemon=True)
            self.worker_thread.start()
            print("Started continuous async worker thread")

    def stop(self):
        """Stop worker thread.

        Clears the running flag and waits up to 5 seconds for the loop to
        drain; the thread is a daemon, so a timeout does not block exit.
        """
        self.running = False
        if self.worker_thread and self.worker_thread.is_alive():
            self.worker_thread.join(timeout=5)
            print("Stopped async worker thread")

    def get_completed_groups(self) -> List[tuple]:
        """Drain and return all currently finished (group_id, group) tuples.

        Non-blocking: returns an empty list when nothing has completed yet.
        """
        completed = []
        while True:
            try:
                result = self.output_queue.get_nowait()
                completed.append(result)
            except queue.Empty:
                break
        return completed

    def get_queue_size(self) -> int:
        """Get current output queue size (approximate, per queue.Queue.qsize)."""
        return self.output_queue.qsize()
148+
149+
150+
async def generate_rollout_async(args, rollout_id: int, data_buffer) -> List[List[Sample]]:
    """
    Simplified asynchronous rollout generation - using global continuous worker.

    Polls the persistent worker's output queue until ``args.rollout_batch_size``
    sample groups have been collected, then returns them sorted by the index of
    each group's first sample.

    Args:
        args: rollout configuration; ``rollout_global_dataset`` must be set and
            ``rollout_batch_size`` gives the number of groups to collect.
        rollout_id: id of the current rollout step (not used in this body).
        data_buffer: prompt source handed to the persistent worker on first call.

    Returns:
        List of sample groups (each a ``List[Sample]``), sorted by group index.
    """
    assert args.rollout_global_dataset

    # Get global worker, which will run continuously across calls; the first
    # call creates it, later calls only drain its output.
    worker = get_global_worker(args, data_buffer)

    # Simplified: directly use rollout_batch_size as target
    target_data_size = args.rollout_batch_size

    data = []
    completed_groups = {}  # group_id -> group, staged before being appended to data
    do_print = True  # print the very first finished sample once, for debugging

    print(f"Starting async rollout generation for {target_data_size} groups")
    print(f"Global worker queue size: {worker.get_queue_size()}")

    # Main loop: collect results from global worker's output queue
    start_time = time.time()
    last_progress_time = start_time
    no_progress_timeout = 30.0  # Warn if no progress for 30 seconds

    while len(data) < target_data_size:
        # Collect completed results (non-blocking drain)
        completed = worker.get_completed_groups()

        made_progress = False
        for group_id, group in completed:
            completed_groups[group_id] = group
            made_progress = True

        if made_progress:
            last_progress_time = time.time()

        # Process completed groups in order (try to maintain order, but not strict requirement)
        processed_any = False

        # Process all available completed groups
        available_ids = list(completed_groups.keys())
        for group_id in available_ids:
            if len(data) >= target_data_size:
                # Enough groups collected; leftovers stay in completed_groups.
                # NOTE(review): leftovers are local to this call and are dropped
                # when it returns — TODO confirm that is intended.
                break

            group = completed_groups.pop(group_id)

            if do_print:
                print(
                    f"First rollout sample: {[group[0].prompt + group[0].response]}, "
                    f"label: {group[0].label}, reward: {group[0].reward}",
                    flush=True,
                )
                do_print = False

            # Simplified: directly add samples, no filters used
            data.append(group)
            processed_any = True

        # Check progress: warn (and reset the timer) when nothing has arrived
        # for no_progress_timeout seconds, so stalls are visible in logs.
        current_time = time.time()
        if current_time - last_progress_time > no_progress_timeout:
            print(
                f"Warning: No progress for {no_progress_timeout}s. "
                f"Queue size: {worker.get_queue_size()}, "
                f"Collected: {len(data)}/{target_data_size}"
            )
            last_progress_time = current_time

        # If no results were processed, brief sleep to avoid busy waiting
        if not processed_any:
            await asyncio.sleep(0.01)

    duration = time.time() - start_time
    print(f"Rollout completed in {duration:.2f}s! Global worker queue size: {worker.get_queue_size()}")

    if data:
        print(
            f"Finish rollout: {[data[-1][0].prompt + data[-1][0].response]}, "
            f"label: {data[-1][0].label}, reward: {data[-1][0].reward}",
            flush=True,
        )

    # Best-effort ordering: sort collected groups by their first sample's index.
    data = sorted(data, key=lambda group: group[0].index)
    return data
235+
236+
237+
def generate_rollout_fully_async(args, rollout_id, data_buffer, evaluation=False):
    """Synchronous entry point: drive ``generate_rollout_async`` to completion.

    Args:
        args: rollout configuration, forwarded unchanged.
        rollout_id: id of the current rollout step, forwarded unchanged.
        data_buffer: prompt source, forwarded unchanged.
        evaluation: must be False; the continuous worker only produces
            training rollouts.

    Raises:
        ValueError: if ``evaluation`` is True.
    """
    if evaluation:
        raise ValueError("Evaluation mode not supported in simple async rollout")
    return run(generate_rollout_async(args, rollout_id, data_buffer))
243+
244+
245+
# Register exit cleanup function: ask the background worker to stop at
# interpreter shutdown. The worker thread is created with daemon=True in
# AsyncRolloutWorker.start(), so this is best-effort cleanup rather than a
# correctness requirement.
import atexit

atexit.register(stop_global_worker)

0 commit comments

Comments
 (0)