|
| 1 | +import asyncio |
| 2 | +import threading |
| 3 | +import queue |
| 4 | +import time |
| 5 | +from typing import List |
| 6 | + |
| 7 | +from slime.utils.async_utils import run |
| 8 | +from slime.utils.types import Sample |
| 9 | + |
| 10 | +# Import core functions from sglang_rollout directly to avoid code duplication |
| 11 | +from slime.rollout.sglang_rollout import generate_and_rm_group, GenerateState |
| 12 | + |
# Global worker manager
_global_worker = None  # Singleton AsyncRolloutWorker shared across rollout calls.
_worker_lock = threading.Lock()  # Guards creation/teardown of _global_worker.
| 16 | + |
| 17 | + |
def get_global_worker(args, data_buffer):
    """Return the shared AsyncRolloutWorker, (re)creating it when needed.

    Thread-safe: creation is serialized by the module-level lock. A fresh
    worker is built when none exists yet or when the previous worker's
    thread has died.
    """
    global _global_worker
    with _worker_lock:
        stale = _global_worker is None or not _global_worker.worker_thread.is_alive()
        if stale:
            print("Creating new global async worker...")
            worker = AsyncRolloutWorker(args, data_buffer, concurrency=args.sglang_server_concurrency)
            worker.start()
            _global_worker = worker
        return _global_worker
| 27 | + |
| 28 | + |
def stop_global_worker():
    """Shut down the shared worker (if any) and clear the module reference."""
    global _global_worker
    with _worker_lock:
        worker = _global_worker
        if worker is None:
            return
        worker.stop()
        _global_worker = None
| 36 | + |
| 37 | + |
class AsyncRolloutWorker:
    """Thread-based asynchronous rollout worker.

    Runs a private asyncio event loop in a daemon thread that continuously
    pulls sample groups from ``data_buffer``, generates rollouts for them via
    ``generate_and_rm_group``, and pushes ``(group_id, result)`` tuples onto a
    thread-safe output queue. The worker outlives individual calls to the
    rollout function, so generation keeps running between training steps.
    """

    def __init__(self, args, data_buffer, concurrency=10):
        self.args = args
        self.data_buffer = data_buffer  # Source of sample groups; read from the worker thread.
        # NOTE(review): currently unused — concurrency is actually bounded by
        # args.rollout_batch_size in continuous_worker_loop; confirm intent.
        self.concurrency = concurrency
        self.running = True  # Polled by the event loop; cleared by stop().
        self.output_queue = queue.Queue(maxsize=1000)  # Completed (group_id, group) results.
        self.worker_thread = None
        self.state = GenerateState(args)  # Holds sampling params shared by all tasks.

    async def continuous_worker_loop(self):
        """Main coroutine: spawn generation tasks and reap their results.

        Keeps at most ``args.rollout_batch_size`` generation tasks in flight.
        Each task delivers its result through a done-callback, so results
        become visible to consumers as soon as they are ready.
        """
        print("Continuous async rollout worker started")

        active_tasks = set()
        max_concurrent_tasks = self.args.rollout_batch_size
        group_id_counter = 0

        def make_callback(gid):
            # Bind gid per task (avoids the late-binding closure pitfall).
            def task_done_callback(task):
                # Bug fix: guard result(). A failed or cancelled task would
                # otherwise raise inside the event-loop callback — the
                # exception went to asyncio's exception handler and the group
                # was silently lost without ever reaching the output queue.
                if task.cancelled():
                    print(f"Rollout task for group {gid} was cancelled")
                    return
                exc = task.exception()
                if exc is not None:
                    print(f"Rollout task for group {gid} failed: {exc}")
                    return
                self.output_queue.put((gid, task.result()))

            return task_done_callback

        while self.running:
            try:
                # Reap finished tasks. Their results were already delivered by
                # the done-callbacks; here we only drop our references and log
                # failures once more for visibility.
                if active_tasks:
                    done_tasks = {task for task in active_tasks if task.done()}
                    for task in done_tasks:
                        try:
                            task.result()  # Results are already handled in callbacks
                        except Exception as e:
                            print(f"Task failed with exception: {e}")
                    active_tasks -= done_tasks

                # Start at most one new group per outer iteration (note the
                # `break` below), which together with the 1s sleep throttles
                # task creation while still respecting the concurrency cap.
                while len(active_tasks) < max_concurrent_tasks and self.running:
                    samples = self.data_buffer.get_samples(1)

                    for group in samples:
                        group_id = group_id_counter
                        group_id_counter += 1

                        # Launch generation + reward-model scoring for the group.
                        task = asyncio.create_task(
                            generate_and_rm_group(
                                self.args,
                                group,
                                sampling_params=self.state.sampling_params.copy(),
                                evaluation=False,
                            )
                        )
                        task.add_done_callback(make_callback(group_id))
                        active_tasks.add(task)
                    break

                # Brief sleep to avoid busy waiting
                await asyncio.sleep(1)

            except Exception as e:
                print(f"Error in continuous worker loop: {e}")
                await asyncio.sleep(1)

        # Drain in-flight work before exiting so results are not abandoned.
        if active_tasks:
            print(f"Waiting for {len(active_tasks)} continuous tasks to complete...")
            await asyncio.wait(active_tasks)

        print("Continuous async rollout worker stopped")

    def worker_thread_func(self):
        """Thread target: run the worker's event loop to completion."""
        asyncio.run(self.continuous_worker_loop())

    def start(self):
        """Start the worker thread if it is not already running."""
        if self.worker_thread is None or not self.worker_thread.is_alive():
            self.worker_thread = threading.Thread(target=self.worker_thread_func, daemon=True)
            self.worker_thread.start()
            print("Started continuous async worker thread")

    def stop(self):
        """Signal the loop to exit and wait briefly for the thread to finish."""
        self.running = False
        if self.worker_thread and self.worker_thread.is_alive():
            # Daemon thread: may still be draining tasks if the join times out.
            self.worker_thread.join(timeout=5)
        print("Stopped async worker thread")

    def get_completed_groups(self) -> List[tuple]:
        """Drain and return all currently available (group_id, group) results."""
        completed = []
        while True:
            try:
                completed.append(self.output_queue.get_nowait())
            except queue.Empty:
                break
        return completed

    def get_queue_size(self) -> int:
        """Return the number of results currently waiting in the output queue."""
        return self.output_queue.qsize()
| 148 | + |
| 149 | + |
async def generate_rollout_async(args, rollout_id: int, data_buffer) -> List[List[Sample]]:
    """Collect ``args.rollout_batch_size`` completed sample groups.

    Generation itself is delegated to the global continuous worker; this
    coroutine only drains the worker's output queue until enough groups have
    been gathered, warning when no new results arrive for a while. Groups are
    returned sorted by the index of their first sample.
    """
    assert args.rollout_global_dataset

    # The worker runs continuously across calls; we only consume results here.
    worker = get_global_worker(args, data_buffer)

    # Simplified: the target is exactly one batch of groups.
    target_data_size = args.rollout_batch_size

    data = []
    pending_groups = {}
    first_print_pending = True

    print(f"Starting async rollout generation for {target_data_size} groups")
    print(f"Global worker queue size: {worker.get_queue_size()}")

    start_time = time.time()
    last_progress_time = start_time
    no_progress_timeout = 30.0  # Warn after this many seconds without progress.

    while len(data) < target_data_size:
        # Pull in whatever the worker has finished since the last pass.
        fresh = worker.get_completed_groups()
        for group_id, group in fresh:
            pending_groups[group_id] = group
        if fresh:
            last_progress_time = time.time()

        # Consume buffered groups (roughly in arrival order, not strictly).
        consumed_any = False
        for group_id in list(pending_groups):
            if len(data) >= target_data_size:
                break

            group = pending_groups.pop(group_id)

            if first_print_pending:
                print(
                    f"First rollout sample: {[group[0].prompt + group[0].response]}, "
                    f"label: {group[0].label}, reward: {group[0].reward}",
                    flush=True,
                )
                first_print_pending = False

            # No filtering: every completed group is accepted as-is.
            data.append(group)
            consumed_any = True

        now = time.time()
        if now - last_progress_time > no_progress_timeout:
            print(
                f"Warning: No progress for {no_progress_timeout}s. "
                f"Queue size: {worker.get_queue_size()}, "
                f"Collected: {len(data)}/{target_data_size}"
            )
            last_progress_time = now

        if not consumed_any:
            # Nothing to do right now; yield briefly instead of busy-waiting.
            await asyncio.sleep(0.01)

    duration = time.time() - start_time
    print(f"Rollout completed in {duration:.2f}s! Global worker queue size: {worker.get_queue_size()}")

    if data:
        print(
            f"Finish rollout: {[data[-1][0].prompt + data[-1][0].response]}, "
            f"label: {data[-1][0].label}, reward: {data[-1][0].reward}",
            flush=True,
        )

    # NOTE(review): any completed groups beyond target_data_size still held in
    # pending_groups are discarded when this function returns — confirm that
    # dropping over-produced groups is acceptable.
    return sorted(data, key=lambda group: group[0].index)
| 235 | + |
| 236 | + |
def generate_rollout_fully_async(args, rollout_id, data_buffer, evaluation=False):
    """Synchronous entry point: run the async rollout and return its groups.

    Raises:
        ValueError: if ``evaluation`` is True — this simplified path only
            supports training rollouts.
    """
    if evaluation:
        raise ValueError("Evaluation mode not supported in simple async rollout")
    return run(generate_rollout_async(args, rollout_id, data_buffer))
| 243 | + |
| 244 | + |
# Register exit cleanup function so the daemon worker thread is stopped
# gracefully when the interpreter shuts down.
import atexit

atexit.register(stop_global_worker)
0 commit comments