|
| 1 | +import datetime |
| 2 | +import json |
| 3 | +import os |
| 4 | +import random |
| 5 | +import subprocess |
| 6 | +from pathlib import Path |
| 7 | + |
# Absolute, symlink-resolved location of this file.
_this_file = Path(os.path.abspath(__file__)).resolve()
# Parent of the directory that contains this file; the functions below look up
# `scripts/models/*.sh` relative to it (presumably the repository root).
repo_base_dir = _this_file.parents[1]
| 9 | + |
| 10 | + |
def convert_checkpoint(model_name, model_type, num_gpus: int = 8):
    """Run ``tools/convert_hf_to_torch_dist.py`` for a HF checkpoint.

    The output is written to ``/root/{model_name}_torch_dist``. When that path
    already exists the conversion is skipped entirely.

    Args:
        model_name: Name of the checkpoint directory under ``/root/models``.
        model_type: Stem of the ``scripts/models/<model_type>.sh`` file that is
            sourced to populate the ``MODEL_ARGS`` bash array.
        num_gpus: Value passed to ``torchrun --nproc-per-node``. Defaults to 8,
            the value that used to be hard-coded.

    Raises:
        subprocess.CalledProcessError: If the conversion command exits with a
            non-zero status.
    """
    # TODO shall we make it in host-mapped folder and thus can cache it to speedup CI
    path_dst = f"/root/{model_name}_torch_dist"
    if Path(path_dst).exists():
        print(f"convert_checkpoint skip {path_dst} since exists")
        return

    # NOTE(review): the converter script path is relative, so this presumably
    # has to be invoked with the repository root as the working directory.
    exec_command(
        # Quoted so that a repo path containing spaces still works (matches
        # how execute_train sources the same file).
        f'source "{repo_base_dir}/scripts/models/{model_type}.sh" && '
        f"PYTHONPATH=/root/Megatron-LM torchrun --nproc-per-node {num_gpus} tools/convert_hf_to_torch_dist.py "
        "${MODEL_ARGS[@]} "
        f"--hf-checkpoint /root/models/{model_name} "
        f"--save {path_dst}"
    )
| 25 | + |
| 26 | + |
def execute_train(
    train_args: str,
    num_gpus: int,
    model_type: str,
    master_addr: str = "127.0.0.1",
):
    """Restart a local Ray head node and submit ``train.py`` as a Ray job.

    The function runs three shell commands in sequence:

    1. A best-effort kill of leftover sglang / ray / slime / redis processes.
    2. ``ray start --head`` bound to ``master_addr`` with ``num_gpus`` GPUs.
    3. ``ray job submit`` of ``python3 train.py ${MODEL_ARGS[@]} <train_args>``
       after sourcing ``scripts/models/<model_type>.sh``.

    Args:
        train_args: Extra command-line arguments appended verbatim to the
            ``train.py`` invocation (interpreted by bash).
        num_gpus: Number of GPUs advertised to the Ray head node.
        model_type: Stem of the ``scripts/models/<model_type>.sh`` file that
            defines the ``MODEL_ARGS`` bash array.
        master_addr: IP address the Ray head node is started on.

    Raises:
        subprocess.CalledProcessError: If starting Ray or the submitted job
            exits with a non-zero status.
    """
    cleanup_steps = [
        "pkill -9 sglang",
        "sleep 3",
        "ray stop --force",
        "pkill -9 ray",
        # `pkill -9 python` is intentionally not used: in CI it would also
        # kill the parent script.
        # TODO: do we really need this kill? (or can we instead kill slime)
        "pkill -9 slime",
        "sleep 3",
        "pkill -9 ray",
        "pkill -9 slime",
        "pkill -9 redis",
        # pkill exits non-zero when nothing matched; end on `true` so the
        # cleanup never trips exec_command's check=True.
        "true",
    ]
    exec_command("; ".join(cleanup_steps))

    # Prevents buffering of stdout/stderr. The variable CPython reads is
    # PYTHONUNBUFFERED; the previous spelling "PYTHONBUFFERED" had no effect.
    unbuffered = "export PYTHONUNBUFFERED=1"

    exec_command(
        f"{unbuffered} && "
        f"ray start --head --node-ip-address {master_addr} --num-gpus {num_gpus} --disable-usage-stats"
    )

    runtime_env_json = json.dumps(
        {
            "env_vars": {
                "PYTHONPATH": "/root/Megatron-LM/",
                "CUDA_DEVICE_MAX_CONNECTIONS": "1",
                # "1" when NVLink is detected on this machine, otherwise "0".
                "NCCL_NVLS_ENABLE": str(int(check_has_nvlink())),
                "no_proxy": f"127.0.0.1,{master_addr}",
            }
        }
    )

    exec_command(
        f"export no_proxy=127.0.0.1 && {unbuffered} && "
        f'source "{repo_base_dir}/scripts/models/{model_type}.sh" && '
        # TODO should this 127.0.0.1 be `master_addr` instead
        'ray job submit --address="http://127.0.0.1:8265" '
        f"--runtime-env-json='{runtime_env_json}' "
        "-- python3 train.py "
        "${MODEL_ARGS[@]} "
        f"{train_args}"
    )
| 77 | + |
| 78 | + |
def check_has_nvlink():
    """Report whether ``nvidia-smi topo -m`` lists any ``NV<digits>`` link.

    The shell pipeline counts the ``NV<digits>`` tokens in the topology matrix;
    stderr of ``nvidia-smi`` is discarded, so a failing ``nvidia-smi`` simply
    yields a count of zero.

    Returns:
        True if at least one such token was found, otherwise False.
    """
    nvlink_probe = "nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l"
    link_count = int(exec_command(nvlink_probe, capture_output=True))
    return link_count > 0
| 82 | + |
| 83 | + |
def get_default_wandb_args(test_file: str):
    """Build the wandb command-line flags for one CI test run.

    Args:
        test_file: Path of the test file; its stem becomes the suffix of the
            wandb project name (``slime-ci-<stem>``).

    Returns:
        An empty string when ``WANDB_API_KEY`` is unset or empty. Otherwise a
        flag string (with a trailing space) whose group name is
        ``<timestamp>-<random int>``, followed by ``_<GITHUB_COMMIT_NAME>``
        when that environment variable is set.
    """
    has_api_key = bool(os.environ.get("WANDB_API_KEY"))
    if not has_api_key:
        print("Skip wandb configuration since WANDB_API_KEY is not found")
        return ""

    project = f"slime-ci-{Path(test_file).stem}"

    timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
    nonce = random.randint(0, 1000000000)
    group = f"{timestamp}-{nonce}"

    commit_name = os.environ.get("GITHUB_COMMIT_NAME")
    if commit_name is not None:
        group = f"{group}_{commit_name}"

    flags = [
        "--use-wandb",
        f"--wandb-project {project}",
        f"--wandb-group {group}",
        # Pass the key as a shell variable reference rather than its value so
        # the secret is not written to the logs explicitly.
        "--wandb-key ${WANDB_API_KEY}",
    ]
    return " ".join(flags) + " "
| 102 | + |
| 103 | + |
def exec_command(cmd: str, capture_output: bool = False):
    """Echo ``cmd`` and run it through ``bash -c``.

    Args:
        cmd: Shell command line to execute.
        capture_output: When true, stdout/stderr are captured instead of being
            inherited from this process.

    Returns:
        The captured stdout as ``bytes`` when ``capture_output`` is true,
        otherwise ``None``.

    Raises:
        subprocess.CalledProcessError: If the command exits non-zero.
    """
    print(f"EXEC: {cmd}", flush=True)
    completed = subprocess.run(
        ["bash", "-c", cmd],
        shell=False,
        check=True,
        capture_output=capture_output,
    )
    return completed.stdout if capture_output else None
0 commit comments