@@ -39,6 +39,7 @@ def execute_train(
3939 train_args : str ,
4040 # TODO rename to "num_gpus_per_node"
4141 num_gpus : int ,
42+ # TODO rename to "megatron_model_type"
4243 model_type : Optional [str ],
4344 train_script : str = "train.py" ,
4445 before_ray_job_submit = None ,
@@ -47,6 +48,9 @@ def execute_train(
4748 external_ray = bool (int (os .environ .get ("SLIME_SCRIPT_EXTERNAL_RAY" , "0" )))
4849 master_addr = os .environ .get ("MASTER_ADDR" , "127.0.0.1" )
4950
51+ train_backend_fsdp = "--train-backend fsdp" in train_args
52+ assert train_backend_fsdp == (model_type is None )
53+
5054 exec_command (
5155 "pkill -9 sglang; "
5256 "sleep 3; "
@@ -78,7 +82,14 @@ def execute_train(
7882 {
7983 "env_vars" : {
8084 "PYTHONPATH" : "/root/Megatron-LM/" ,
81- "CUDA_DEVICE_MAX_CONNECTIONS" : "1" ,
85+ # Do not set this under FSDP: it may cause issues with computation/communication overlap
86+ ** (
87+ {}
88+ if train_backend_fsdp
89+ else {
90+ "CUDA_DEVICE_MAX_CONNECTIONS" : "1" ,
91+ }
92+ ),
8293 "NCCL_NVLS_ENABLE" : str (int (check_has_nvlink ())),
8394 "no_proxy" : f"127.0.0.1,{ master_addr } " ,
8495 # This is needed by megatron / torch distributed in multi-node setup
0 commit comments