Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions csub.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def build_parser() -> argparse.ArgumentParser:
parser.add_argument("-i", "--image", type=str, help="Override RUNAI_IMAGE from the env file")
parser.add_argument("-p", "--port", type=int, help="Expose a container port")
parser.add_argument("--train", action="store_true", help="Submit as a training workload")
parser.add_argument("--distributed", action="store_true", help="Submit a distributed workload")
parser.add_argument("--workers", default=0, type=int, help="Only read for distributed workloads. Number of nodes IN ADDITION to the master node. I.e., the total number of nodes is the number of workers + 1 (the master node)")
parser.add_argument("--workers", default=0, type=int, help="Number of nodes IN ADDITION to the master node. I.e., the total number of nodes is the number of workers + 1 (the master node)")
parser.add_argument("--dry", action="store_true", help="Print the generated runai command")
parser.add_argument("--env-file", type=str, default=DEFAULT_ENV_FILE, help="Path to the .env file (default: .env in the repo root)")
parser.add_argument("--sync-secret-only", action="store_true", help="Create/refresh the Kubernetes secret and exit without submitting a job")
Expand All @@ -66,8 +65,10 @@ def build_parser() -> argparse.ArgumentParser:
def build_runai_command(
args: argparse.Namespace, env: Dict[str, str]
) -> Tuple[List[str], str]:
assert args.train + args.distributed <= 1, "Choose --train or --distributed but not both"

distributed = args.workers > 0
if not args.train and distributed:
args.train |= distributed
print("Forcing non-interactive as distributed")
job_name = (
args.name
or f"{env['LDAP_USERNAME']}-{datetime.now().strftime('%Y%m%d-%H%M%S')}"
Expand Down Expand Up @@ -130,7 +131,7 @@ def build_runai_command(
shell_command += f" && {user_command}"

cmd: List[str] = ["runai"]
cmd.extend(["submit-dist", "pytorch"] if args.distributed else ["submit"])
cmd.extend(["submit-dist", "pytorch"] if distributed else ["submit"])
cmd.extend([
"--name",
job_name,
Expand Down Expand Up @@ -158,7 +159,7 @@ def build_runai_command(
if args.memory:
cmd.extend(["--memory", args.memory])

if not args.train and not args.distributed:
if not args.train:
cmd.append("--interactive")
else:
cmd.extend(["--backoff-limit", str(args.backofflimit)])
Expand All @@ -173,10 +174,10 @@ def build_runai_command(

if args.node_type:
cmd.extend(["--node-pools", args.node_type])
if args.node_type in {"h200", "h100"} and not args.train and not args.distributed:
if args.node_type in {"h200", "h100"} and not args.train:
cmd.append("--preemptible")

if args.distributed:
if distributed:
cmd.extend([
"--workers", str(args.workers),
"--annotation", "k8s.v1.cni.cncf.io/networks=kube-system/roce",
Expand Down