Skip to content

Commit 3a5f601

Browse files
committed
cp
1 parent fb87a0e commit 3a5f601

File tree

1 file changed

+84
-34
lines changed

1 file changed

+84
-34
lines changed

tests/command_utils.py

Lines changed: 84 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,56 +2,73 @@
22
import json
33
import os
44
import random
5-
import subprocess
65
from pathlib import Path
6+
from typing import Optional
7+
from slime.utils.misc import exec_command
8+
9+
_ = exec_command
710

811
# Absolute, symlink-resolved path of the directory two levels above this file
# (parents[0] is the directory holding this file, parents[1] is its parent).
# Used below to locate scripts/models/<model_type>.sh.
repo_base_dir = Path(os.path.abspath(__file__)).resolve().parents[1]
912

1013

11-
def convert_checkpoint(model_name, model_type):
14+
def convert_checkpoint(model_name, model_type, num_gpus: int, dir_dst="/root"):
    """Convert a Hugging Face checkpoint to ``torch_dist`` format unless it already exists.

    The output is written to ``{dir_dst}/{model_name}_torch_dist``. When that
    path is already present the conversion is skipped and a message is printed.

    Args:
        model_name: Directory name of the checkpoint under ``/root/models``.
        model_type: Base name of the ``scripts/models/<model_type>.sh`` script
            that is sourced to populate ``MODEL_ARGS``.
        num_gpus: Value passed to torchrun's ``--nproc-per-node``.
        dir_dst: Directory under which the converted checkpoint is saved.
    """
    # TODO shall we make it in host-mapped folder and thus can cache it to speedup CI
    target = f"{dir_dst}/{model_name}_torch_dist"

    if not Path(target).exists():
        command_parts = [
            f"source {repo_base_dir}/scripts/models/{model_type}.sh &&",
            f"PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node {num_gpus} tools/convert_hf_to_torch_dist.py",
            "${MODEL_ARGS[@]}",
            f"--hf-checkpoint /root/models/{model_name}",
            f"--save {target}",
        ]
        exec_command(" ".join(command_parts))
        return

    print(f"convert_checkpoint skip {target} since exists")
2528

2629

30+
def hf_download_dataset(full_name: str):
    """Download a Hugging Face dataset into ``/root/datasets/<name>`` via the ``hf`` CLI.

    Args:
        full_name: Dataset repo id, normally ``<namespace>/<name>``. The local
            directory is named after the last ``/``-separated component, so an
            id without a namespace is accepted as well.
    """
    # Take the last component instead of unpacking a two-way split: unpacking
    # raised ValueError for any id that did not contain exactly one "/".
    partial_name = full_name.rsplit("/", 1)[-1]
    exec_command(f"hf download --repo-type dataset {full_name} --local-dir /root/datasets/{partial_name}")
33+
34+
2735
def execute_train(
2836
train_args: str,
2937
num_gpus: int,
30-
model_type: str,
31-
master_addr: str = "127.0.0.1",
38+
model_type: Optional[str],
39+
train_script: str = "train.py",
40+
before_ray_job_submit=None,
41+
extra_env_vars={},
3242
):
43+
external_ray = bool(int(os.environ.get("MILES_SCRIPT_EXTERNAL_RAY", "0")))
44+
master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
45+
3346
exec_command(
3447
"pkill -9 sglang; "
3548
"sleep 3; "
36-
"ray stop --force; "
37-
"pkill -9 ray; "
49+
f"{'' if external_ray else 'ray stop --force; '}"
50+
f"{'' if external_ray else 'pkill -9 ray; '}"
3851
# cannot be run in CI, o/w kill the parent script
3952
# TODO: do we really need this kill? (or can we instead kill slime)
4053
# "pkill -9 python; "
4154
"pkill -9 slime; "
4255
"sleep 3; "
43-
"pkill -9 ray; "
56+
f"{'' if external_ray else 'pkill -9 ray; '}"
4457
# "pkill -9 python; "
4558
"pkill -9 slime; "
4659
"pkill -9 redis; "
4760
"true; "
4861
)
4962

50-
exec_command(
51-
# will prevent ray from buffering stdout/stderr
52-
f"export PYTHONBUFFERED=16 && "
53-
f"ray start --head --node-ip-address {master_addr} --num-gpus {num_gpus} --disable-usage-stats"
54-
)
63+
if not external_ray:
64+
exec_command(
65+
# will prevent ray from buffering stdout/stderr
66+
f"export PYTHONBUFFERED=16 && "
67+
f"ray start --head --node-ip-address {master_addr} --num-gpus {num_gpus} --disable-usage-stats"
68+
)
69+
70+
if (f := before_ray_job_submit) is not None:
71+
f()
5572

5673
runtime_env_json = json.dumps(
5774
{
@@ -60,49 +77,82 @@ def execute_train(
6077
"CUDA_DEVICE_MAX_CONNECTIONS": "1",
6178
"NCCL_NVLS_ENABLE": str(int(check_has_nvlink())),
6279
"no_proxy": f"127.0.0.1,{master_addr}",
80+
# This is needed by megatron / torch distributed in multi-node setup
81+
"MASTER_ADDR": master_addr,
82+
**extra_env_vars,
6383
}
6484
}
6585
)
6686

67-
exec_command(
68-
f"export no_proxy=127.0.0.1 && export PYTHONBUFFERED=16 && "
69-
f'source "{repo_base_dir}/scripts/models/{model_type}.sh" && '
70-
# TODO should this 127.0.0.1 be `master_addr` instead
71-
f'ray job submit --address="http://127.0.0.1:8265" '
72-
f"--runtime-env-json='{runtime_env_json}' "
73-
"-- python3 train.py "
74-
"${MODEL_ARGS[@]} "
75-
f"{train_args}"
76-
)
87+
source_cmd = f'source "{repo_base_dir}/scripts/models/{model_type}.sh" && ' if model_type is not None else ""
88+
model_args_str = "${MODEL_ARGS[@]}" if model_type is not None else ""
89+
90+
if bool(int(os.environ.get("MILES_SCRIPT_ENABLE_RAY_SUBMIT", "1"))):
91+
exec_command(
92+
f"export PYTHONBUFFERED=16 && "
93+
f"{source_cmd}"
94+
# TODO should this 127.0.0.1 be `master_addr` instead
95+
f'ray job submit --address="http://127.0.0.1:8265" '
96+
f"--runtime-env-json='{runtime_env_json}' "
97+
f"-- python3 {train_script} "
98+
f"{model_args_str} "
99+
f"{train_args}"
100+
)
77101

78102

79103
def check_has_nvlink():
    """Return True when ``nvidia-smi topo -m`` lists at least one ``NV<n>`` link entry."""
    # The pipeline counts the NV<n> matches; stderr from nvidia-smi is discarded.
    nvlink_match_count = exec_command(
        "nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l",
        capture_output=True,
    )
    return int(nvlink_match_count) > 0
82106

83107

84-
def get_default_wandb_args(test_file: str):
108+
def get_default_wandb_args(test_file: str, run_name_prefix: Optional[str] = None, run_id: Optional[str] = None):
    """Build the wandb command-line arguments for a CI test run.

    Args:
        test_file: Path of the test file; its stem names the wandb project.
            Stems shorter than 6 characters are prefixed with the parent
            directory name.
        run_name_prefix: Optional prefix prepended to the wandb group name.
        run_id: Optional run id for the group name; a fresh one is created
            with ``create_run_id()`` when omitted or empty.

    Returns:
        The argument string (with a trailing space), or ``""`` when
        ``WANDB_API_KEY`` is unset or empty.
    """
    if not os.environ.get("WANDB_API_KEY"):
        print("Skip wandb configuration since WANDB_API_KEY is not found")
        return ""

    test_path = Path(test_file)
    project_suffix = test_path.stem
    if len(project_suffix) < 6:
        project_suffix = f"{test_path.parent.name}_{project_suffix}"

    group_name = run_id or create_run_id()
    commit_name = os.environ.get("GITHUB_COMMIT_NAME")
    if commit_name is not None:
        group_name = f"{group_name}_{commit_name}"
    if run_name_prefix is not None:
        group_name = f"{run_name_prefix}_{group_name}"

    # do not put wandb_api_key value here to avoid leaking to logs explicitly
    flags = [
        "--use-wandb",
        f"--wandb-project slime-ci-{project_suffix}",
        f"--wandb-group {group_name}",
        "--wandb-key ${WANDB_API_KEY}",
        "--disable-wandb-random-suffix",
    ]
    return "".join(f"{flag} " for flag in flags)
102132

103133

104-
def exec_command(cmd: str, capture_output: bool = False):
105-
print(f"EXEC: {cmd}", flush=True)
106-
result = subprocess.run(["bash", "-c", cmd], shell=False, check=True, capture_output=capture_output)
107-
if capture_output:
108-
return result.stdout
134+
def create_run_id() -> str:
    """Return a run id shaped ``YYMMDD-HHMMSS-NNN``: a local timestamp plus a random 3-digit suffix."""
    timestamp = datetime.datetime.now().strftime("%y%m%d-%H%M%S")
    suffix = random.Random().randint(0, 999)
    return f"{timestamp}-{suffix:03d}"
136+
137+
138+
# Names of environment variables for which an "unrecognized value" warning
# has already been printed, so each variable is reported at most once.
_warned_bool_env_var_keys = set()


# copied from SGLang
def get_bool_env_var(name: str, default: str = "false") -> bool:
    """Read an environment variable and interpret it as a boolean.

    The value is compared case-insensitively: ``"true"`` and ``"1"`` are
    truthy, ``"false"`` and ``"0"`` are falsy. Any other value is treated as
    false, and a warning is printed the first time it is seen for ``name``.

    Args:
        name: Environment variable to read.
        default: String used when the variable is not set.

    Returns:
        True only when the (lower-cased) value is ``"true"`` or ``"1"``.
    """
    value = os.getenv(name, default).lower()

    truthy_values = ("true", "1")
    falsy_values = ("false", "0")

    if value not in truthy_values and value not in falsy_values:
        # Deduplicate by variable name, not by value: keying on the value
        # silenced the warning for every later variable that happened to
        # carry the same unrecognized string.
        if name not in _warned_bool_env_var_keys:
            print(f"get_bool_env_var({name}) see non-understandable value={value} and treat as false")
            _warned_bool_env_var_keys.add(name)

    return value in truthy_values
155+
156+
157+
def get_env_enable_infinite_run():
    """Return the boolean value of ``MILES_TEST_ENABLE_INFINITE_RUN`` (false when unset)."""
    return get_bool_env_var("MILES_TEST_ENABLE_INFINITE_RUN", default="false")

0 commit comments

Comments
 (0)