|
| 1 | +""" |
| 2 | +This is a file for TorchX components used for testing torchft. |
| 3 | +""" |
| 4 | + |
| 5 | +import os |
| 6 | +from typing import Dict, Optional |
| 7 | + |
| 8 | +import torchx.specs as specs |
| 9 | + |
| 10 | + |
def hsdp(
    *script_args: str,
    replicas: int = 2,
    workers_per_replica: int = 1,
    max_restarts: int = 10,
    script: str = "train_ddp.py",
    env: Optional[Dict[str, str]] = None,
    image: str = "",
    h: Optional[str] = None,
    cpu: int = 2,
    gpu: int = 0,
    memMB: int = 1024,
) -> specs.AppDef:
    """Build a TorchX ``AppDef`` with one ``torchrun`` role per replica group.

    One role named ``replica_<id>`` is created for each replica. Every role
    uses ``torchrun`` as its entrypoint with a distinct ``--master_port``
    (``29600 + replica_id``).

    Args:
        *script_args: Extra arguments appended after ``script`` on the
            ``torchrun`` command line.
        replicas: Number of roles to create. Must be > 0.
        workers_per_replica: Passed as ``--nproc_per_node`` and also used as
            each role's ``min_replicas``/``num_replicas``. Must be > 0.
        max_restarts: Passed to ``torchrun`` as ``--max_restarts``.
        script: Script path added to the ``torchrun`` arguments; omitted
            when empty.
        env: Environment variables for the roles. Defaults are filled in
            only for keys that are not already present. The mapping passed
            in is not modified.
        image: Image forwarded to each ``specs.Role``.
        h: Named resource forwarded to ``specs.resource``.
        cpu: CPU count forwarded to ``specs.resource``.
        gpu: GPU count forwarded to ``specs.resource``. When it is 0 (or
            less), ``CUDA_VISIBLE_DEVICES`` defaults to the empty string.
        memMB: Memory in MB forwarded to ``specs.resource``.

    Returns:
        A ``specs.AppDef`` named ``"torchft"`` containing the roles.

    Raises:
        AssertionError: If ``replicas`` or ``workers_per_replica`` is not
            positive.
    """
    assert replicas > 0, "replicas must be > 0"
    assert workers_per_replica > 0, "workers_per_replica must be > 0"

    # Work on a copy so the defaults below never leak into the caller's dict.
    env = dict(env) if env else {}

    # Enable logging for PyTorch, torchelastic and Rust.
    env.setdefault("TORCH_CPP_LOG_LEVEL", "INFO")
    env.setdefault("LOGLEVEL", "INFO")
    env.setdefault("RUST_BACKTRACE", "1")

    # Enable colored logging for torchft Rust logger.
    env.setdefault("CLICOLOR_FORCE", "1")

    # Set lighthouse address for replicas.
    # The lighthouse itself must be run externally.
    env.setdefault(
        "TORCHFT_LIGHTHOUSE",
        os.environ.get("TORCHFT_LIGHTHOUSE", "http://localhost:29510"),
    )

    # Disable CUDA for CPU-only jobs. Jobs that request GPUs must keep their
    # devices visible, so the default is only applied when no GPU is asked for.
    if gpu <= 0:
        env.setdefault("CUDA_VISIBLE_DEVICES", "")

    roles = []
    for replica_id in range(replicas):
        cmd = [
            f"--master_port={29600 + replica_id}",
            "--nnodes=1",
            f"--nproc_per_node={workers_per_replica}",
            f"--max_restarts={max_restarts}",
        ]
        if script:
            cmd.append(script)
        cmd.extend(script_args)

        # NOTE(review): num_replicas=workers_per_replica together with
        # --nnodes=1 and a single master port per role looks like it could
        # start several independent torchrun instances on the same port when
        # workers_per_replica > 1 -- confirm this is intended.
        roles.append(
            specs.Role(
                name=f"replica_{replica_id}",
                image=image,
                min_replicas=workers_per_replica,
                num_replicas=workers_per_replica,
                resource=specs.resource(cpu=cpu, gpu=gpu, memMB=memMB, h=h),
                max_retries=0,
                # Each role gets its own dict so later edits to one role's
                # environment cannot affect the others.
                env=dict(env),
                entrypoint="torchrun",
                args=cmd,
            )
        )

    return specs.AppDef(
        name="torchft",
        roles=roles,
    )
0 commit comments