22import json
33import os
44import random
5- import subprocess
65from pathlib import Path
6+ from typing import Optional
7+ from slime .utils .misc import exec_command
8+
9+ _ = exec_command
710
# Repository root: this file lives two directory levels below it.
repo_base_dir = Path(os.path.abspath(__file__)).resolve().parent.parent
912
1013
def convert_checkpoint(model_name, model_type, num_gpus: int, dir_dst="/root"):
    """Convert a HuggingFace checkpoint into Megatron torch_dist format.

    The conversion is skipped entirely when the destination directory already
    exists, so repeated runs reuse the earlier result.

    Args:
        model_name: model folder name under ``/root/models``.
        model_type: selects the ``scripts/models/<model_type>.sh`` file that is
            sourced to populate ``MODEL_ARGS``.
        num_gpus: number of processes handed to ``torchrun``.
        dir_dst: directory the converted checkpoint is written into.
    """
    # TODO shall we make it in host-mapped folder and thus can cache it to speedup CI
    destination = f"{dir_dst}/{model_name}_torch_dist"
    if Path(destination).exists():
        print(f"convert_checkpoint skip {destination} since exists")
        return

    command = (
        f"source {repo_base_dir}/scripts/models/{model_type}.sh && "
        f"PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node {num_gpus} tools/convert_hf_to_torch_dist.py "
        "${MODEL_ARGS[@]} "
        f"--hf-checkpoint /root/models/{model_name} "
        f"--save {destination} "
    )
    exec_command(command)
2528
2629
def hf_download_dataset(full_name: str):
    """Download a HuggingFace dataset repo into ``/root/datasets/<repo-name>``.

    Args:
        full_name: dataset identifier, normally ``org/name``.
    """
    # Use the last path component so both "org/name" and a bare "name" work;
    # the previous tuple-unpack of split("/") raised ValueError for anything
    # that did not contain exactly one slash.
    partial_name = full_name.split("/")[-1]
    exec_command(f"hf download --repo-type dataset {full_name} --local-dir /root/datasets/{partial_name} ")
33+
34+
2735def execute_train (
2836 train_args : str ,
2937 num_gpus : int ,
30- model_type : str ,
31- master_addr : str = "127.0.0.1" ,
38+ model_type : Optional [str ],
39+ train_script : str = "train.py" ,
40+ before_ray_job_submit = None ,
41+ extra_env_vars = {},
3242):
43+ external_ray = bool (int (os .environ .get ("MILES_SCRIPT_EXTERNAL_RAY" , "0" )))
44+ master_addr = os .environ .get ("MASTER_ADDR" , "127.0.0.1" )
45+
3346 exec_command (
3447 "pkill -9 sglang; "
3548 "sleep 3; "
36- " ray stop --force; "
37- " pkill -9 ray; "
49+ f" { '' if external_ray else ' ray stop --force; ' } "
50+ f" { '' if external_ray else ' pkill -9 ray; ' } "
3851 # cannot be run in CI, o/w kill the parent script
3952 # TODO: do we really need this kill? (or can we instead kill slime)
4053 # "pkill -9 python; "
4154 "pkill -9 slime; "
4255 "sleep 3; "
43- " pkill -9 ray; "
56+ f" { '' if external_ray else ' pkill -9 ray; ' } "
4457 # "pkill -9 python; "
4558 "pkill -9 slime; "
4659 "pkill -9 redis; "
4760 "true; "
4861 )
4962
50- exec_command (
51- # will prevent ray from buffering stdout/stderr
52- f"export PYTHONBUFFERED=16 && "
53- f"ray start --head --node-ip-address { master_addr } --num-gpus { num_gpus } --disable-usage-stats"
54- )
63+ if not external_ray :
64+ exec_command (
65+ # will prevent ray from buffering stdout/stderr
66+ f"export PYTHONBUFFERED=16 && "
67+ f"ray start --head --node-ip-address { master_addr } --num-gpus { num_gpus } --disable-usage-stats"
68+ )
69+
70+ if (f := before_ray_job_submit ) is not None :
71+ f ()
5572
5673 runtime_env_json = json .dumps (
5774 {
@@ -60,49 +77,82 @@ def execute_train(
6077 "CUDA_DEVICE_MAX_CONNECTIONS" : "1" ,
6178 "NCCL_NVLS_ENABLE" : str (int (check_has_nvlink ())),
6279 "no_proxy" : f"127.0.0.1,{ master_addr } " ,
80+ # This is needed by megatron / torch distributed in multi-node setup
81+ "MASTER_ADDR" : master_addr ,
82+ ** extra_env_vars ,
6383 }
6484 }
6585 )
6686
67- exec_command (
68- f"export no_proxy=127.0.0.1 && export PYTHONBUFFERED=16 && "
69- f'source "{ repo_base_dir } /scripts/models/{ model_type } .sh" && '
70- # TODO should this 127.0.0.1 be `master_addr` instead
71- f'ray job submit --address="http://127.0.0.1:8265" '
72- f"--runtime-env-json='{ runtime_env_json } ' "
73- "-- python3 train.py "
74- "${MODEL_ARGS[@]} "
75- f"{ train_args } "
76- )
87+ source_cmd = f'source "{ repo_base_dir } /scripts/models/{ model_type } .sh" && ' if model_type is not None else ""
88+ model_args_str = "${MODEL_ARGS[@]}" if model_type is not None else ""
89+
90+ if bool (int (os .environ .get ("MILES_SCRIPT_ENABLE_RAY_SUBMIT" , "1" ))):
91+ exec_command (
92+ f"export PYTHONBUFFERED=16 && "
93+ f"{ source_cmd } "
94+ # TODO should this 127.0.0.1 be `master_addr` instead
95+ f'ray job submit --address="http://127.0.0.1:8265" '
96+ f"--runtime-env-json='{ runtime_env_json } ' "
97+ f"-- python3 { train_script } "
98+ f"{ model_args_str } "
99+ f"{ train_args } "
100+ )
77101
78102
def check_has_nvlink():
    """Return True when ``nvidia-smi`` topology reports at least one NVLink (NVx) link."""
    nvlink_count = exec_command(
        "nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l", capture_output=True
    )
    return int(nvlink_count) > 0
82106
83107
def get_default_wandb_args(test_file: str, run_name_prefix: Optional[str] = None, run_id: Optional[str] = None):
    """Build the wandb CLI arguments for a CI test run.

    Returns an empty string (wandb disabled) when WANDB_API_KEY is absent.

    Args:
        test_file: path of the calling test file; its stem names the project.
        run_name_prefix: optional prefix prepended to the run group name.
        run_id: optional explicit run id; a fresh one is generated otherwise.
    """
    if not os.environ.get("WANDB_API_KEY"):
        print("Skip wandb configuration since WANDB_API_KEY is not found")
        return ""

    path = Path(test_file)
    project_suffix = path.stem
    # Very short stems are ambiguous, so qualify them with the parent dir name.
    if len(project_suffix) < 6:
        project_suffix = f"{path.parent.name}_{project_suffix}"

    group = run_id if run_id else create_run_id()
    commit_name = os.environ.get("GITHUB_COMMIT_NAME")
    if commit_name is not None:
        group = f"{group}_{commit_name}"
    if run_name_prefix is not None:
        group = f"{run_name_prefix}_{group}"

    # do not put wandb_api_key value here to avoid leaking to logs explicitly
    return (
        "--use-wandb "
        f"--wandb-project slime-ci-{project_suffix} "
        f"--wandb-group {group} "
        f"--wandb-key ${{WANDB_API_KEY}} "
        "--disable-wandb-random-suffix "
    )
102132
103133
def create_run_id() -> str:
    """Return a timestamp-based run id, e.g. ``240131-235959-042``."""
    timestamp = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    # A dedicated Random() instance is seeded from OS entropy, so a globally
    # seeded `random` module elsewhere cannot make ids repeat.
    suffix = random.Random().randint(0, 999)
    return f"{timestamp}-{suffix:03d}"
136+
137+
# Env-var names that have already produced a "non-understandable value" warning.
_warned_bool_env_var_keys = set()


# copied from SGLang
def get_bool_env_var(name: str, default: str = "false") -> bool:
    """Read env var ``name`` and interpret it as a boolean.

    Truthy values are "true"/"1", falsy are "false"/"0" (case-insensitive).
    Anything else is treated as false, with a one-time warning per variable
    name.

    Args:
        name: environment variable to read.
        default: value assumed when the variable is unset.
    """
    value = os.getenv(name, default)
    value = value.lower()

    truthy_values = ("true", "1")
    falsy_values = ("false", "0")

    if (value not in truthy_values) and (value not in falsy_values):
        # Deduplicate by variable *name*: the previous code added `value` to
        # the set, so the same key warned repeatedly for different bad values
        # while a different key with an already-seen bad value stayed silent.
        if name not in _warned_bool_env_var_keys:
            print(f"get_bool_env_var({name}) see non-understandable value={value} and treat as false")
            _warned_bool_env_var_keys.add(name)

    return value in truthy_values
155+
156+
def get_env_enable_infinite_run():
    """Whether the MILES_TEST_ENABLE_INFINITE_RUN flag is set (defaults to false)."""
    flag_name = "MILES_TEST_ENABLE_INFINITE_RUN"
    return get_bool_env_var(flag_name, "false")