Skip to content

Commit 73fc400

Browse files
committed
multiflex start
1 parent 9133085 commit 73fc400

2 files changed

Lines changed: 165 additions & 17 deletions

File tree

bin/_startup.sh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,35 @@ GROUPING_TRAINER_BUCKET=$(curl -fsS -H "Metadata-Flavor: Google" \
1616
export GROUPING_TRAINER_BUCKET
1717

1818

19+
# ----------------------------------------------------------------------------------------------------------------------
20+
# Multi-flex-start lock race
21+
# ----------------------------------------------------------------------------------------------------------------------
22+
# When gt.launch fans out flex-start submits across multiple zones (--multi_flex_start), every sibling VM gets the same
23+
# `launch-id` metadata. The first to reach this point claims the GCS object atomically via --if-generation-match=0;
24+
# losers see a 412 and self-delete.
25+
#
26+
# Any non-zero `gcloud storage cp` (412 race-loss or a transient gcloud error) self-deletes. Prefer over-deleting (user
27+
# needs to retry the launch) to under-deleting (bunch of colliding work).
28+
# `launch-id` is only set on VMs created by a multi-flex-start launch; a missing attribute (curl fails) leaves
# LAUNCH_ID empty and the whole race is skipped.
LAUNCH_ID=$(curl -fsS -H "Metadata-Flavor: Google" \
    http://metadata.google.internal/computeMetadata/v1/instance/attributes/launch-id 2>/dev/null || true)
if [ -n "$LAUNCH_ID" ]; then
    LOCK_PATH="gs://${GROUPING_TRAINER_BUCKET}/launches/${LAUNCH_ID}/winner"
    # --if-generation-match=0 makes the upload succeed only if the object does not exist yet, so exactly one sibling
    # can create it.
    # NOTE(review): if this script runs again on a later boot of the winning VM, the lock already exists and the
    # winner would take the self-delete branch below — confirm the script only runs once per VM.
    if hostname | gcloud storage cp - "$LOCK_PATH" --if-generation-match=0; then
        echo "Congratulations! You won the race to $LOCK_PATH"
    else
        echo "You lost the race to $LOCK_PATH. Self-deleting. Better luck next time"
        # Each lookup has a fallback so a metadata-server hiccup doesn't trip `set -e` before we get to
        # `gcloud compute instances delete`.
        NAME=$(curl -fsS -H 'Metadata-Flavor: Google' \
            http://metadata.google.internal/computeMetadata/v1/instance/name) || NAME=$(hostname)
        # Capture curl's output first and trim afterwards: a pipeline's exit status is that of its last command
        # unless `pipefail` is on, so piping curl straight into a filter would hide a curl failure from the `||`.
        ZONE_PATH=$(curl -fsS -H 'Metadata-Flavor: Google' \
            http://metadata.google.internal/computeMetadata/v1/instance/zone 2>/dev/null) || ZONE_PATH=""
        # Keep only the text after the last `/` (the bare zone name).
        ZONE="${ZONE_PATH##*/}"
        gcloud compute instances delete "$NAME" --zone="$ZONE" --quiet
        exit 0
    fi
fi
46+
47+
1948
# ----------------------------------------------------------------------------------------------------------------------
2049
# Set up python env
2150
# ----------------------------------------------------------------------------------------------------------------------

src/grouping_trainer/launch.py

Lines changed: 136 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,16 @@
2020
import sys
2121
import tempfile
2222
import time
23+
import uuid
2324
from datetime import datetime
2425
from enum import StrEnum
25-
from typing import Literal
26+
from typing import Literal, NoReturn
2627

2728
from pydantic import BaseModel, ConfigDict
2829
from tap import tapify
2930

31+
from grouping_trainer.logging import configure_logging
32+
3033
logger = logging.getLogger(__name__)
3134

3235
_REMOTE_ENV_VAR = "GROUPING_TRAINER_REMOTE"
@@ -87,9 +90,11 @@ class GpuConfig(BaseModel):
8790
reservation_affinity: Literal["none", "any"]
8891
is_for_training: bool
8992
n_gpu: int
93+
boot_disk_type: str = "pd-balanced"
94+
"""Boot disk type. A3 Ultra (a3-ultragpu-8g) rejects pd-balanced and needs hyperdisk-balanced."""
9095

9196

92-
# gcloud compute accelerator-types list --filter="name=nvidia-l4" --format='value(zone)'
97+
# gcloud compute accelerator-types list --filter="name=nvidia-l4" --format='value(zone)' | sort -u
9398
L4_ZONES = (
9499
"us-central1-a",
95100
"us-central1-b",
@@ -168,6 +173,18 @@ class GpuConfig(BaseModel):
168173
"us-east4-c",
169174
"europe-west4-a",
170175
)
176+
# gcloud compute accelerator-types list --filter="name=nvidia-h200-141gb" --format='value(zone)'
177+
H200_ZONES = (
178+
"us-central1-b",
179+
"us-east4-b",
180+
"us-east5-a",
181+
"us-south1-b",
182+
"us-west1-c",
183+
"asia-south1-b",
184+
"asia-south2-c",
185+
"europe-west1-b",
186+
"europe-west4-a",
187+
)
171188

172189
GpuType = Literal[
173190
"l4",
@@ -177,6 +194,7 @@ class GpuConfig(BaseModel):
177194
"a100",
178195
"a100-ddp-2",
179196
"a100-ddp-4",
197+
"h200-ddp-8",
180198
]
181199

182200
gpu_type_to_config: dict[GpuType, GpuConfig] = {
@@ -257,6 +275,18 @@ class GpuConfig(BaseModel):
257275
is_for_training=True,
258276
n_gpu=4,
259277
),
278+
"h200-ddp-8": GpuConfig(
279+
flex_start_zone="us-central1-a",
280+
standard_zones=H200_ZONES,
281+
machine_type="a3-ultragpu-8g",
282+
accelerator=None,
283+
max_run_duration="172800s",
284+
install_nvidia_driver=True,
285+
reservation_affinity="none",
286+
is_for_training=True,
287+
n_gpu=8,
288+
boot_disk_type="hyperdisk-balanced",
289+
),
260290
}
261291

262292

@@ -314,6 +344,17 @@ def _is_stockout(stderr: str) -> bool:
314344
return "ZONE_RESOURCE_POOL_EXHAUSTED" in stderr
315345

316346

347+
def _raise_gce_create_failure(result: subprocess.CompletedProcess[str], args: list[str]) -> NoReturn:
    """
    Log the failed gcloud call's stderr, then raise it as a `subprocess.CalledProcessError`.

    The stderr is logged explicitly because CalledProcessError's repr drops it.

    Parameters
    ----------
    result
        The completed (failed) gcloud process; its return code, stdout and stderr are copied onto the exception.
    args
        The argv that was run, attached to the exception as its `cmd`.

    Raises
    ------
    subprocess.CalledProcessError
        Always.
    """
    exit_code = result.returncode
    captured_stderr = result.stderr
    logger.error(f"gcloud failed (exit {exit_code}):\n{captured_stderr}")
    failure = subprocess.CalledProcessError(exit_code, args, output=result.stdout, stderr=captured_stderr)
    raise failure
356+
357+
317358
def _gce_create_cmd(
318359
config: GpuConfig,
319360
instance_name: str,
@@ -322,11 +363,14 @@ def _gce_create_cmd(
322363
provisioning_model: Literal["FLEX_START", "STANDARD"],
323364
wait_for_instance_creation: bool,
324365
path_to_metadata_script: dict[str, str],
366+
launch_id: str | None = None,
325367
) -> list[str]:
326368
metadata = {"gcs-bucket": os.environ["GROUPING_TRAINER_BUCKET"]}
327369
if config.install_nvidia_driver:
328370
metadata["enable-osconfig"] = "TRUE"
329371
metadata["install-nvidia-driver"] = "True"
372+
if launch_id is not None:
373+
metadata["launch-id"] = launch_id
330374

331375
args = [
332376
"gcloud",
@@ -346,7 +390,7 @@ def _gce_create_cmd(
346390
"--scopes=https://www.googleapis.com/auth/cloud-platform",
347391
(
348392
f"--create-disk=auto-delete=yes,boot=yes,device-name={instance_name},"
349-
f"image={_IMAGE},mode=rw,size=200,type=pd-balanced"
393+
f"image={_IMAGE},mode=rw,size=200,type={config.boot_disk_type}"
350394
),
351395
"--no-shielded-secure-boot",
352396
"--shielded-vtpm",
@@ -362,13 +406,69 @@ def _gce_create_cmd(
362406
return args
363407

364408

409+
def _gce_multi_flex_start(
    *,
    config: GpuConfig,
    base_instance_name: str,
    num_zones: int,
    path_to_metadata_script: dict[str, str],
) -> None:
    """
    Fan out async FLEX_START submits across the first `num_zones` of `config.standard_zones`, all sharing the same
    `launch-id` metadata. First VM to boot claims a GCS atomic-create lock (see `bin/_startup.sh`); losers self-delete.

    Parameters
    ----------
    config
        GPU config; its `standard_zones`, in order, are the fan-out candidates.
    base_instance_name
        Name given to every sibling VM (the same name is reused in each zone).
    num_zones
        How many zones to submit to. Must be >= 1. Values larger than `len(config.standard_zones)` are capped by the
        slice.
    path_to_metadata_script
        Passed through unchanged to `_gce_create_cmd`.

    Raises
    ------
    ValueError
        If `num_zones < 1` or `config.standard_zones` is empty. Raised before any VM is submitted.
    subprocess.CalledProcessError
        On the first non-stockout gcloud failure. Remaining zones are not attempted; siblings submitted earlier are
        left running and will still race for the lock.
    RuntimeError
        If every attempted zone stocked out.
    """
    # Validate before slicing. `[:0]` yields no zones, which would otherwise be misreported below as "all 0 zones
    # stocked out"; a negative bound silently drops zones from the END of the tuple instead of limiting the count.
    if num_zones < 1:
        raise ValueError(f"num_zones must be >= 1, got {num_zones}")
    zones = config.standard_zones[:num_zones]
    if not zones:
        raise ValueError("Multi-flex-start: GPU config has no standard_zones to fan out to")

    # Short random id shared by all siblings; the startup script uses it to build the lock path.
    launch_id = uuid.uuid4().hex[:12]
    logger.info(f"Multi-flex-start launch-id={launch_id}, fanning out to {len(zones)} zones: {zones}")
    n_submitted = 0
    last_stockout_stderr = ""
    for zone in zones:
        gce_create_args = _gce_create_cmd(
            config,
            base_instance_name,  # instance names only need to be unique within (project, zone).
            zone,
            provisioning_model="FLEX_START",
            wait_for_instance_creation=False,
            path_to_metadata_script=path_to_metadata_script,
            launch_id=launch_id,
        )
        result = subprocess.run(gce_create_args, capture_output=True, text=True)
        if result.returncode == 0:
            logger.info(f"Flex-started {base_instance_name} in zone {zone}")
            n_submitted += 1
            continue

        if _is_stockout(result.stderr):  # I've gotten a stockout on flex-starts before
            logger.warning(f"Stockout in {zone}. Continuing with remaining zones")
            last_stockout_stderr = result.stderr
            continue

        # Non-stockout error: fail loudly. Any earlier siblings already submitted will still race for the lock.
        if n_submitted > 0:
            logger.warning(
                f"{n_submitted} sibling(s) already submitted will still race for the lock; "
                "you may still end up with a VM despite this error."
            )
        _raise_gce_create_failure(result, gce_create_args)

    # `zones` is non-empty here, so zero submissions means every attempt was a stockout.
    if n_submitted == 0:
        suffix = f" Last stderr:\n{last_stockout_stderr}" if last_stockout_stderr else ""
        raise RuntimeError(f"Multi-flex-start: all {len(zones)} zones stocked out.{suffix}")
    logger.info(
        f"Multi-flex-start launched {n_submitted}/{len(zones)} VMs (launch-id={launch_id}). "
        "First to boot claims the winner lock, others self-delete."
    )
461+
462+
365463
def gce_vm(
366464
*,
367465
gpu: GpuType,
368466
# TODO: support StrEnum-typed params in tap
369467
job_type: Literal[tuple(member.value for member in JobType)] = JobType.SSH, # type: ignore[valid-type]
370468
name_suffix: str = "",
371469
sync_start: bool = False,
470+
multi_flex_start: bool = False,
471+
multi_flex_start_num_zones: int = 10,
372472
command: str | None = None,
373473
zone: str | None = None,
374474
num_cycles_through_zones: int = 5,
@@ -379,7 +479,9 @@ def gce_vm(
379479
`_startup.sh` `eval`s it after env setup. If not, the instance just sets up the env and stops for when you want to
380480
SSH in and iterate.
381481
382-
The instance's name is `grouping-trainer-{gpu}-{job_type}[-{name_suffix}]` (the suffix is dropped when empty).
482+
The instance's name is `grouping-trainer-{gpu}-{job_type}[-{name_suffix}]` (the suffix is dropped when empty). In
483+
`--multi_flex_start` mode, siblings share the same name — GCE instance names are unique per (project, zone), and
484+
the zone column of `gcloud compute instances list` already distinguishes them.
383485
384486
Parameters
385487
----------
@@ -393,24 +495,35 @@ def gce_vm(
393495
sync_start
394496
If False (default), flex-starts the instance—GCP waits up to 1h to find one. `--sync_start` uses on-demand
395497
pricing and finds an instance in any zone, as flex-starting often can't find instances in time.
498+
multi_flex_start
499+
Fan out async FLEX_START submits across the first `multi_flex_start_num_zones` zones of `standard_zones`. First
500+
VM to boot wins via a GCS atomic-create lock; losers self-delete. Mutually exclusive with `sync_start`. Not
501+
applicable to GPU types without a `flex_start_zone` (e.g. L4).
502+
multi_flex_start_num_zones
503+
How many zones to fan out to in multi-flex mode. Capped at the length of the GPU config's `standard_zones`.
396504
command
397505
The command to run on the remote instance. If not given, the instance just sets up the env and stops for when
398506
you want to SSH in and iterate.
399507
zone
400508
Override the GCP zone. In FLEX_START mode pins the single submit zone; in STANDARD mode pins the loop to that
401-
zone (still retried `num_cycles_through_zones` times).
509+
zone (still retried `num_cycles_through_zones` times). Ignored in `--multi_flex_start` mode.
402510
num_cycles_through_zones
403511
No-op for sync_start. Otherwise, loop through zones this many times before giving up.
404512
seconds_between_gce_create_attempts
405513
No-op for sync_start. Otherwise, sleep this many seconds b/t consecutive zone attempts.
406514
"""
407515
config = gpu_type_to_config[gpu]
516+
if multi_flex_start and sync_start:
517+
raise ValueError("--multi_flex_start and --sync_start are mutually exclusive")
518+
if multi_flex_start and config.flex_start_zone is None:
519+
raise ValueError(f"GPU type {gpu!r} does not support FLEX_START, so --multi_flex_start is not applicable")
520+
408521
instance_name = f"{_INSTANCE_NAME_PREFIX}-{gpu}-{job_type}"
409522
if name_suffix:
410523
# GCE instance names must match `[a-z]([-a-z0-9]*[a-z0-9])?`, so swap underscores for hyphens.
411524
instance_name += f"-{name_suffix.replace('_', '-')}"
412525

413-
if (not sync_start) and (config.flex_start_zone is not None):
526+
if (not sync_start) and (not multi_flex_start) and (config.flex_start_zone is not None):
414527
provisioning_model: Literal["FLEX_START", "STANDARD"] = "FLEX_START"
415528
wait_for_instance_creation = False
416529
zones_to_try: tuple[str, ...] = (zone or config.flex_start_zone,)
@@ -431,6 +544,15 @@ def gce_vm(
431544
cmd_file.flush()
432545
path_to_metadata_script["command"] = cmd_file.name
433546

547+
if multi_flex_start:
548+
_gce_multi_flex_start(
549+
config=config,
550+
base_instance_name=instance_name,
551+
num_zones=multi_flex_start_num_zones,
552+
path_to_metadata_script=path_to_metadata_script,
553+
)
554+
return
555+
434556
for zone_idx, zone in enumerate(zones_to_try):
435557
# Attempt creation in this zone
436558
gce_create_args = _gce_create_cmd(
@@ -444,7 +566,8 @@ def gce_vm(
444566
logger.info(f"Attempting to create {instance_name} in zone {zone}")
445567
gce_create_cmd_result = subprocess.run(gce_create_args, capture_output=True, text=True)
446568
if gce_create_cmd_result.returncode == 0:
447-
logger.info(f"Created {instance_name} in zone {zone}")
569+
creation_type = "Flex-started" if provisioning_model == "FLEX_START" else "Created"
570+
logger.info(f"{creation_type} {instance_name} in zone {zone}")
448571
return
449572

450573
# Retry next zone if stockout
@@ -455,16 +578,9 @@ def gce_vm(
455578
continue
456579

457580
# Fail
458-
logger.error(f"gcloud failed (exit {gce_create_cmd_result.returncode}):\n{gce_create_cmd_result.stderr}")
459581
if not sync_start:
460-
logger.warning("You may have success with --sync_start and no GCP zone pinning.")
461-
# CalledProcessError's repr drops stderr, so log it
462-
raise subprocess.CalledProcessError(
463-
gce_create_cmd_result.returncode,
464-
gce_create_args,
465-
output=gce_create_cmd_result.stdout,
466-
stderr=gce_create_cmd_result.stderr,
467-
)
582+
logger.warning("You may have success with --multi_flex_start if you're fine waiting, else --sync_start")
583+
_raise_gce_create_failure(gce_create_cmd_result, gce_create_args)
468584

469585

470586
def run_argv_remotely(
@@ -473,6 +589,7 @@ def run_argv_remotely(
473589
job_type: JobType,
474590
name_suffix: str,
475591
sync_start: bool = False,
592+
multi_flex_start: bool = False,
476593
zone: str | None = None,
477594
env_var_to_value: dict[str, str] | None = None,
478595
) -> None:
@@ -513,7 +630,7 @@ def run_argv_remotely(
513630
# Script path and args
514631
script_path = os.path.relpath(os.path.abspath(sys.argv[0]), _repo_root())
515632
argv_remote = _strip_flags_and_their_values(argv=sys.argv[1:], flags=("--gpu", "--zone"))
516-
argv_remote = _strip_bool_flags(argv=argv_remote, flags=("--sync_start",))
633+
argv_remote = _strip_bool_flags(argv=argv_remote, flags=("--sync_start", "--multi_flex_start"))
517634
args_for_remote = shlex.join(argv_remote)
518635

519636
command_parts.append(script_path)
@@ -527,10 +644,12 @@ def run_argv_remotely(
527644
job_type=job_type,
528645
name_suffix=name_suffix,
529646
sync_start=sync_start,
647+
multi_flex_start=multi_flex_start,
530648
command=command,
531649
zone=zone,
532650
)
533651

534652

535653
if __name__ == "__main__":
654+
configure_logging(process_type="launch")
536655
tapify(gce_vm)

0 commit comments

Comments
 (0)