 from torchtitan.config import ConfigManager, JobConfig
 from torchtitan.tools.logging import init_logger, logger
 from torchtitan.train import Trainer
+from utils.failure import Failure, FailureActor, FailureController


 # ==== Allocation boilerplate - much of this will be upstreamed into Monarch ====
 class MonarchSlurm:
     # Cluster Configuration - update these values for your specific cluster
-    machine: str = "aws_g5.12xlarge"
-    machine_memory: int = 186777
+    machine: str = "gpu.xlarge"
+    machine_memory: int = 2062607
     job_name_prefix: str = "monarch-torchft"

-    job_handles: Dict[str, str] = {}
+    def __init__(self):
+        self.job_handles: Dict[str, str] = {}
+        atexit.register(self.kill_jobs)

-    @classmethod
-    def get_config(cls, mesh_name: str, nodes_per_mesh: int) -> Config:
+    def get_config(self, mesh_name: str, nodes_per_mesh: int) -> Config:
         mesh = [f"{mesh_name}:{nodes_per_mesh}:{MonarchSlurm.machine}"]
-        appdef = hyperactor.host_mesh(meshes=mesh)
+        # to enable relative import of utils on actors
+        current_dir = os.path.dirname(os.path.abspath(__file__))
+        env = {"PYTHONPATH": current_dir}
+
+        appdef = hyperactor.host_mesh(meshes=mesh, env=env)

         for role in appdef.roles:
             role.resource.memMB = MonarchSlurm.machine_memory

         return Config(scheduler="slurm", appdef=appdef)

-    @classmethod
-    async def get_or_create_job(cls, mesh_name: str, nodes_per_mesh: int = 1) -> None:
-        config = cls.get_config(mesh_name, nodes_per_mesh)
+    async def get_or_create_job(self, mesh_name: str, nodes_per_mesh: int = 1) -> None:
+        config = self.get_config(mesh_name, nodes_per_mesh)
         job_name = f"{MonarchSlurm.job_name_prefix}-{mesh_name}"
         server_spec = await commands.get_or_create(job_name, config, force_restart=True)
-        cls.job_handles[mesh_name] = server_spec.name
+        self.job_handles[mesh_name] = server_spec.name

-    @classmethod
-    def kill_jobs(cls):
-        for mesh_name, job_handle in cls.job_handles.items():
-            try:
-                logger.info(f"Destroying job for mesh {mesh_name}")
-                commands.kill(f"slurm:///{job_handle}")
-            except Exception as e:
-                logger.warning(f"Failed to destroy job for {mesh_name}: {e}")
+    def kill_jobs(self):
+        for mesh_name in self.job_handles.keys():
+            self.kill_job(mesh_name)
+
+    def kill_job(self, mesh_name: str):
+        try:
+            job_handle = self.job_handles[mesh_name]
+            logger.info(f"Destroying job for mesh {mesh_name}")
+            commands.kill(f"slurm:///{job_handle}")
+        except Exception as e:
+            logger.warning(f"Failed to destroy job for {mesh_name}: {e}")

-    @classmethod
     def proc_mesh(
-        cls,
+        self,
         mesh_name: str,
         num_hosts: int = 1,
         num_gpus: int = 8,
     ) -> ProcMesh:
         allocator = RemoteAllocator(
             world_id=MonarchSlurm.job_name_prefix,
             initializer=TorchXRemoteAllocInitializer(
-                f"slurm:///{cls.job_handles[mesh_name]}"
+                f"slurm:///{self.job_handles[mesh_name]}"
             ),
         )
         alloc = allocator.allocate(
@@ -94,7 +101,7 @@ def start_lighthouse(self) -> str:
         from torchft.coordination import LighthouseServer

         self.lighthouse = LighthouseServer(
-            bind="[::]:0", min_replicas=1, join_timeout_ms=10000
+            bind="[::]:0", min_replicas=1, join_timeout_ms=60000
         )
         return self.lighthouse.address()

@@ -140,6 +147,7 @@ class JobSpec:
     replica_count: int
     hosts_per_replica: int
     gpus_per_node: int
+    with_failures: bool
     lighthouse_address: str = ""


@@ -154,16 +162,15 @@ class Replica:
 # This does not currently benefit from being an actor, but will once
 # Monarch supervision APIs are fleshed out.
 class ReplicaActor(Actor):
-    def __init__(
-        self,
-        spec: JobSpec,
-        replica_id: int,
-    ) -> None:
+    def __init__(self, spec: JobSpec, replica_id: int, scheduler: MonarchSlurm) -> None:
         self.spec = deepcopy(spec)
         self.replica_id = replica_id

         self.uid = f"[replica_{replica_id}]"
         self.spec.job_config.fault_tolerance.replica_id = self.replica_id
+        self.scheduler = scheduler
+
+        self.failure_actors: FailureActor | None = None

     @endpoint
     async def start_replica(self) -> None:
@@ -172,14 +179,12 @@ async def start_replica(self) -> None:

         trainers_proc_mesh: ProcMesh | None = None
         try:
-            trainers_proc_mesh = MonarchSlurm.proc_mesh(
+            trainers_proc_mesh = self.scheduler.proc_mesh(
                 f"replica_{self.replica_id}",
                 self.spec.hosts_per_replica,
                 self.spec.gpus_per_node,
             )
-            await trainers_proc_mesh.logging_option(
-                stream_to_client=True, aggregate_window_sec=None
-            )
+            await trainers_proc_mesh.logging_option(stream_to_client=True)
             await setup_env_for_distributed(trainers_proc_mesh)

             training_actors = trainers_proc_mesh.spawn(
@@ -189,6 +194,10 @@ async def start_replica(self) -> None:
                 self.replica_id,
             )

+            self.failure_actors = trainers_proc_mesh.spawn(
+                "failure_actors", FailureActor
+            )
+
             logger.info(f"{self.uid} Starting trainers")
             await training_actors.start_training.call(self.spec.lighthouse_address)
             await trainers_proc_mesh.stop()
@@ -197,13 +206,29 @@ async def start_replica(self) -> None:
             await trainers_proc_mesh.stop()
             raise e

+    @endpoint
+    async def inject_failure(self, failure_type: Failure):
+        if self.failure_actors:
+            try:
+                logger.info(
+                    f"{self.uid} Injecting failure ({failure_type}) into random trainer"
+                )
+
+                await self.failure_actors.fail.choose(failure_type)
+            except Exception as e:
+                error_msg = f"{self.uid} Injected failure: {e}"
+                logger.error(error_msg)
+        else:
+            error_msg = f"{self.uid} No failure actors available"
+            logger.error(error_msg)
+

 # delay before re-creating proc mesh on existing job. change as needed.
-PROC_ATTEMPT_DELAY = 10
+PROC_ATTEMPT_DELAY = 0
 # proc attempts before getting a new scheduler allocation. change as needed.
-PROC_ATTEMPTS = 2
+PROC_ATTEMPTS = 4
 # attempts before failing training on replica. change as needed.
-MAX_ATTEMPT = PROC_ATTEMPTS * 2
+MAX_ATTEMPT = PROC_ATTEMPTS * 4


 class OrchestrationManager:
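The three constants above drive the replica retry policy implemented in OrchestrationManager._spin_up_replica further down; the full branching is outside the hunks shown in this diff. As a rough, illustrative sketch of how such constants are commonly combined (the helper names here are hypothetical and not part of the file):

# Illustrative only: hypothetical helpers showing one way constants like those
# above combine into a retry policy; the real logic lives in _spin_up_replica.
PROC_ATTEMPT_DELAY = 0  # seconds to wait before reusing an existing allocation
PROC_ATTEMPTS = 4  # proc-mesh attempts per Slurm allocation
MAX_ATTEMPT = PROC_ATTEMPTS * 4  # total attempts before a replica is abandoned


def should_reallocate(attempt_number: int) -> bool:
    # Tear down the replica's Slurm job and request a fresh allocation every
    # PROC_ATTEMPTS consecutive failures.
    return attempt_number > 0 and attempt_number % PROC_ATTEMPTS == 0


def should_give_up(attempt_number: int) -> bool:
    # Stop retrying the replica once MAX_ATTEMPT failures have accumulated.
    return attempt_number >= MAX_ATTEMPT

Under this interpretation, each replica would get up to four proc-mesh attempts per Slurm allocation and sixteen attempts in total before the controller stops retrying it.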
@@ -213,32 +238,41 @@ def __init__(self, spec: JobSpec) -> None:
         self.lighthouse_actor: LighthouseActor | None = None
         self.lighthouse_mesh: ProcMesh | None = None

+        self.scheduler = MonarchSlurm()
+
     async def start_training(self) -> None:
         logger.info(
             f"[Controller] Creating training system with {self.spec.replica_count} replicas"
         )

         for replica_id in range(self.spec.replica_count):
-            await MonarchSlurm.get_or_create_job(
+            await self.scheduler.get_or_create_job(
                 f"replica_{replica_id}", self.spec.hosts_per_replica
             )

         mesh_futures = {}
         for i in range(self.spec.replica_count):
             mesh_futures[i] = asyncio.create_task(self._run_replica(i, 0))

+        failure_future = None
+        if self.spec.with_failures:
+            failure_future = asyncio.create_task(
+                FailureController.execute_failures(self.replicas, self.scheduler)
+            )
+
         await asyncio.gather(*mesh_futures.values(), return_exceptions=True)

+        if failure_future:
+            failure_future.cancel()
+
     async def start_lighthouse(self) -> None:
         if self.spec.remote_lighthouse:
-            await MonarchSlurm.get_or_create_job("lighthouse")
-            self.lighthouse_mesh = MonarchSlurm.proc_mesh("lighthouse", num_gpus=1)
+            await self.scheduler.get_or_create_job("lighthouse")
+            self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_gpus=1)
         else:
             self.lighthouse_mesh = this_host().spawn_procs({"gpus": 1})

-        await self.lighthouse_mesh.logging_option(
-            stream_to_client=True, aggregate_window_sec=None
-        )
+        await self.lighthouse_mesh.logging_option(stream_to_client=True)
         self.lighthouse_actor = self.lighthouse_mesh.spawn(
             "lighthouse_actor", LighthouseActor
         )
@@ -274,7 +308,8 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> None:
             logger.info(
                 f"[Controller] Replica {replica_id} has failed {attempt_number} times. Getting new allocation."
             )
-            await MonarchSlurm.get_or_create_job(
+            self.scheduler.kill_job(f"replica_{replica_id}")
+            await self.scheduler.get_or_create_job(
                 f"replica_{replica_id}", self.spec.hosts_per_replica
             )
         delay = 0 if not attempt_number else PROC_ATTEMPT_DELAY
@@ -287,10 +322,7 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> None:
         await replica_proc_mesh.logging_option(aggregate_window_sec=None)

         replica_actor = replica_proc_mesh.spawn(
-            "replica_actor",
-            ReplicaActor,
-            self.spec,
-            replica_id,
+            "replica_actor", ReplicaActor, self.spec, replica_id, self.scheduler
         )

         replica = Replica(replica_id, replica_proc_mesh, replica_actor, attempt_number)
@@ -301,8 +333,8 @@ async def _teardown(self, replica_id: int) -> None:
         try:
             replica = self.replicas[replica_id]
             await replica.proc_mesh.stop()
-            del replica.proc_mesh
             del self.replicas[replica_id]
+            del replica.proc_mesh
         except Exception as e:
             logger.error(f"[Controller] Failed to _teardown replica {replica_id}: {e}")

@@ -339,20 +371,25 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument(
         "--model-config",
         type=str,
-        default=os.path.join(script_dir, "debug_model.toml"),
-        help=f"Path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
+        default="debug_model.toml",
+        help=f"Relative path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
     )
     parser.add_argument(
         "--dataset-path",
         type=str,
-        default=os.path.join(script_dir, "c4_test"),
-        help=f"Path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
+        default="c4_test",
+        help=f"Relative path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
     )
     parser.add_argument(
         "--tokenizer-path",
         type=str,
-        default=os.path.join(script_dir, "tokenizer"),
-        help=f"Path to tokenizer (default: {os.path.join(script_dir, 'tokenizer')})",
+        default="debug_tokenizer",
+        help=f"Relative path to tokenizer (default: {os.path.join(script_dir, 'debug_tokenizer')})",
+    )
+    parser.add_argument(
+        "--with-failures",
+        action="store_true",
+        help="Enable the failure injector utility (default: False)",
     )

     return parser.parse_args()
@@ -362,32 +399,37 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
     data_parallel_shard_degree = args.gpu_per_node * args.host_per_replica

     output_path = "./outputs"
-    training_dataset = "c4_test"
+    training_dataset = args.dataset_path.split("/")[-1]

+    script_dir = os.path.dirname(os.path.abspath(__file__))
     default_args = [
         "--job.config_file",
-        args.model_config,
+        os.path.join(script_dir, args.model_config),
         "--model.tokenizer_path",
-        args.tokenizer_path,
+        os.path.join(script_dir, args.tokenizer_path),
         "--comm.trace_buf_size",
         "0",
         "--metrics.log_freq",
         "1",
         "--fault_tolerance.enable",
         "--fault_tolerance.group_size",
         str(args.replica_count),
+        "--fault_tolerance.process_group",
+        "nccl",
+        "--fault_tolerance.process_group_timeout_ms",
+        "60000",
         "--parallelism.data_parallel_shard_degree",
         str(data_parallel_shard_degree),
         "--activation_checkpoint.mode",
         "full",
         "--comm.train_timeout_seconds",
-        "60",
+        "300",
         "--training.steps",
         str(args.training_steps),
         "--training.dataset",
         training_dataset,
         "--training.dataset_path",
-        args.dataset_path,
+        os.path.join(script_dir, args.dataset_path),
         "--job.dump_folder",
         output_path,
         "--metrics.enable_tensorboard",
@@ -402,6 +444,7 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
         replica_count=args.replica_count,
         hosts_per_replica=args.host_per_replica,
         gpus_per_node=args.gpu_per_node,
+        with_failures=args.with_failures,
     )


@@ -414,7 +457,6 @@ async def main() -> None:
     args = parse_args()
     job_spec = make_job_spec(args)

-    atexit.register(MonarchSlurm.kill_jobs)
     orchestrator = OrchestrationManager(job_spec)
     try:
         await orchestrator.start_lighthouse()
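The diff pulls Failure, FailureActor, and FailureController from a utils/failure.py module that is not shown in this commit view. Below is a minimal sketch of what such a module could look like, inferred purely from how it is used above; the import path, the Failure members, the injection period, and the replica.actor attribute are assumptions, not the actual implementation.

# Hypothetical sketch of utils/failure.py (not part of this diff).
import asyncio
import os
import random
from enum import Enum

from monarch.actor import Actor, endpoint  # assumed import path


class Failure(Enum):
    # Assumed failure modes; the real module may define different ones.
    PROCESS_KILL = "process_kill"
    EXCEPTION = "exception"


class FailureActor(Actor):
    """Spawned on every trainer proc so injected failures hit real ranks."""

    @endpoint
    async def fail(self, failure_type: Failure) -> None:
        if failure_type == Failure.PROCESS_KILL:
            os._exit(1)  # abrupt process death, exercises torchft recovery
        raise RuntimeError(f"Injected failure: {failure_type}")


class FailureController:
    """Periodically injects a failure into a randomly chosen replica."""

    @staticmethod
    async def execute_failures(replicas, scheduler, period_s: float = 120.0) -> None:
        while True:
            await asyncio.sleep(period_s)
            if not replicas:
                continue
            replica = random.choice(list(replicas.values()))
            # Assumes Replica exposes its ReplicaActor as `.actor`; ReplicaActor
            # then routes the failure to one random trainer via `.choose()`.
            await replica.actor.inject_failure.call(Failure.PROCESS_KILL)
            # `scheduler` is unused in this sketch; the real controller could use
            # it (e.g. kill_job) to simulate whole-host failures instead.

Because ReplicaActor.inject_failure forwards the failure with .choose(), each injection lands on a single randomly selected trainer process, exercising torchft's per-replica recovery path rather than taking down the whole group at once.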