|
29 | 29 | dry_run_suite, |
30 | 30 | ) |
31 | 31 |
|
32 | | -DIST_INIT_PORT = 30000 |
| 32 | +DIST_INIT_PORT = 10011 |
33 | 33 | SERVER_PORT = 30000 |
34 | 34 | CONTROL_PORT = 18080 |
35 | 35 |
|
36 | 36 | _control_state = {"done": False, "exit_code": 0} |
37 | 37 |
|
38 | 38 |
|
| 39 | +def _publish_state(exit_code: int) -> None: |
| 40 | + """Set terminal state visible to worker ranks via /status. |
| 41 | +
|
| 42 | + Writes exit_code BEFORE done so any worker that observes ``done == True`` |
| 43 | + is guaranteed to read the matching exit_code in the same poll. Without |
| 44 | + this ordering a worker can race the two assignments and report success |
| 45 | + while rank 0 actually failed. |
| 46 | + """ |
| 47 | + _control_state["exit_code"] = exit_code |
| 48 | + _control_state["done"] = True |
| 49 | + |
| 50 | + |
| 51 | +def _reset_state() -> None: |
| 52 | + _control_state["done"] = False |
| 53 | + _control_state["exit_code"] = 0 |
| 54 | + |
| 55 | + |
39 | 56 | def _log(message: str) -> None: |
40 | 57 | print(f"[multi-host-suite] {message}", flush=True) |
41 | 58 |
|
@@ -223,44 +240,49 @@ def run_model_run(model_run: ModelRun, runtime_cfg: RuntimeConfig) -> int: |
223 | 240 | _log( |
224 | 241 | f"Launching model run={model_run.name}, rank={runtime_cfg.node_rank}, port={runtime_cfg.port}" |
225 | 242 | ) |
226 | | - control_server = start_control_server() if runtime_cfg.node_rank == 0 else None |
227 | | - base_url = f"http://{runtime_cfg.host}:{runtime_cfg.port}" |
228 | | - server_process = popen_launch_server( |
229 | | - model=model_run.model.model_path, |
230 | | - base_url=base_url, |
231 | | - timeout=1800, |
232 | | - other_args=build_other_server_args(model_run.model, runtime_cfg), |
233 | | - ) |
| 243 | + is_rank0 = runtime_cfg.node_rank == 0 |
| 244 | + _reset_state() |
| 245 | + control_server = start_control_server() if is_rank0 else None |
| 246 | + server_process = None |
234 | 247 | exit_code = 0 |
235 | 248 |
|
236 | 249 | try: |
237 | | - if runtime_cfg.node_rank == 0: |
| 250 | + base_url = f"http://{runtime_cfg.host}:{runtime_cfg.port}" |
| 251 | + server_process = popen_launch_server( |
| 252 | + model=model_run.model.model_path, |
| 253 | + base_url=base_url, |
| 254 | + timeout=1800, |
| 255 | + other_args=build_other_server_args(model_run.model, runtime_cfg), |
| 256 | + ) |
| 257 | + |
| 258 | + if is_rank0: |
238 | 259 | for case in model_run.cases: |
239 | 260 | run_case(case, model_run.model.model_path, runtime_cfg.port) |
240 | | - _control_state["done"] = True |
241 | | - _control_state["exit_code"] = 0 |
242 | 261 | else: |
243 | 262 | workload_name = _get_env("WORKLOAD_NAME") |
244 | 263 | headless_service_name = _get_env("HEADLESS_SERVICE_NAME") |
245 | 264 | control_url = f"http://{workload_name}-0.{headless_service_name}:{CONTROL_PORT}/status" |
246 | 265 | exit_code = wait_for_done(control_url, server_process) |
247 | 266 | except Exception: |
248 | 267 | exit_code = 1 |
249 | | - if runtime_cfg.node_rank == 0: |
250 | | - _control_state["done"] = True |
251 | | - _control_state["exit_code"] = exit_code |
252 | 268 | raise |
253 | 269 | finally: |
254 | | - if runtime_cfg.node_rank == 0: |
| 270 | + if is_rank0: |
| 271 | + # Always publish — covers success, case failure, and popen_launch_server |
| 272 | + # failure (where server_process never got assigned). Without this, |
| 273 | + # workers spin on /status until wait_for_done's 60-min timeout when |
| 274 | + # rank 0 dies during launch. |
| 275 | + _publish_state(exit_code) |
255 | 276 | _log("Keeping control server alive for worker ranks") |
256 | 277 | time.sleep(30) |
257 | | - _log("Stopping server process") |
258 | | - kill_process_tree(server_process.pid) |
259 | | - try: |
260 | | - server_process.wait(timeout=5) |
261 | | - except subprocess.TimeoutExpired: |
262 | | - server_process.kill() |
263 | | - server_process.wait() |
| 278 | + if server_process is not None: |
| 279 | + _log("Stopping server process") |
| 280 | + kill_process_tree(server_process.pid) |
| 281 | + try: |
| 282 | + server_process.wait(timeout=5) |
| 283 | + except subprocess.TimeoutExpired: |
| 284 | + server_process.kill() |
| 285 | + server_process.wait() |
264 | 286 | if control_server is not None: |
265 | 287 | control_server.shutdown() |
266 | 288 |
|
|
0 commit comments