2020import sys
2121import tempfile
2222import time
23+ import uuid
2324from datetime import datetime
2425from enum import StrEnum
25- from typing import Literal
26+ from typing import Literal , NoReturn
2627
2728from pydantic import BaseModel , ConfigDict
2829from tap import tapify
2930
31+ from grouping_trainer .logging import configure_logging
32+
3033logger = logging .getLogger (__name__ )
3134
3235_REMOTE_ENV_VAR = "GROUPING_TRAINER_REMOTE"
@@ -87,9 +90,11 @@ class GpuConfig(BaseModel):
8790 reservation_affinity : Literal ["none" , "any" ]
8891 is_for_training : bool
8992 n_gpu : int
93+ boot_disk_type : str = "pd-balanced"
94+ """Boot disk type. A3 Ultra (a3-ultragpu-8g) rejects pd-balanced and needs hyperdisk-balanced."""
9095
9196
92- # gcloud compute accelerator-types list --filter="name=nvidia-l4" --format='value(zone)'
97+ # gcloud compute accelerator-types list --filter="name=nvidia-l4" --format='value(zone)' | sort -u
9398L4_ZONES = (
9499 "us-central1-a" ,
95100 "us-central1-b" ,
@@ -168,6 +173,18 @@ class GpuConfig(BaseModel):
168173 "us-east4-c" ,
169174 "europe-west4-a" ,
170175)
# gcloud compute accelerator-types list --filter="name=nvidia-h200-141gb" --format='value(zone)'
# Order matters: `_gce_multi_flex_start` submits to the first N entries of a config's `standard_zones`, so zones
# listed earlier are preferred.
# NOTE(review): the "h200-ddp-8" config sets flex_start_zone="us-central1-a", which is not in this list — confirm that
# zone actually offers H200s.
H200_ZONES = (
    "us-central1-b",
    "us-east4-b",
    "us-east5-a",
    "us-south1-b",
    "us-west1-c",
    "asia-south1-b",
    "asia-south2-c",
    "europe-west1-b",
    "europe-west4-a",
)
171188
172189GpuType = Literal [
173190 "l4" ,
@@ -177,6 +194,7 @@ class GpuConfig(BaseModel):
177194 "a100" ,
178195 "a100-ddp-2" ,
179196 "a100-ddp-4" ,
197+ "h200-ddp-8" ,
180198]
181199
182200gpu_type_to_config : dict [GpuType , GpuConfig ] = {
@@ -257,6 +275,18 @@ class GpuConfig(BaseModel):
257275 is_for_training = True ,
258276 n_gpu = 4 ,
259277 ),
278+ "h200-ddp-8" : GpuConfig (
279+ flex_start_zone = "us-central1-a" ,
280+ standard_zones = H200_ZONES ,
281+ machine_type = "a3-ultragpu-8g" ,
282+ accelerator = None ,
283+ max_run_duration = "172800s" ,
284+ install_nvidia_driver = True ,
285+ reservation_affinity = "none" ,
286+ is_for_training = True ,
287+ n_gpu = 8 ,
288+ boot_disk_type = "hyperdisk-balanced" ,
289+ ),
260290}
261291
262292
@@ -314,6 +344,17 @@ def _is_stockout(stderr: str) -> bool:
314344 return "ZONE_RESOURCE_POOL_EXHAUSTED" in stderr
315345
316346
def _raise_gce_create_failure(result: subprocess.CompletedProcess[str], args: list[str]) -> NoReturn:
    """
    Log the gcloud stderr (CalledProcessError's repr drops it) and raise.

    Parameters
    ----------
    result
        The completed (failed) `gcloud` invocation. Its return code, stdout and stderr are copied onto the raised
        exception.
    args
        The argv that produced `result`; becomes the exception's `cmd`.

    Raises
    ------
    subprocess.CalledProcessError
        Always.
    """
    logger.error(f"gcloud failed (exit {result.returncode}):\n{result.stderr}")
    raise subprocess.CalledProcessError(
        result.returncode,
        args,
        output=result.stdout,
        stderr=result.stderr,
    )
356+
357+
317358def _gce_create_cmd (
318359 config : GpuConfig ,
319360 instance_name : str ,
@@ -322,11 +363,14 @@ def _gce_create_cmd(
322363 provisioning_model : Literal ["FLEX_START" , "STANDARD" ],
323364 wait_for_instance_creation : bool ,
324365 path_to_metadata_script : dict [str , str ],
366+ launch_id : str | None = None ,
325367) -> list [str ]:
326368 metadata = {"gcs-bucket" : os .environ ["GROUPING_TRAINER_BUCKET" ]}
327369 if config .install_nvidia_driver :
328370 metadata ["enable-osconfig" ] = "TRUE"
329371 metadata ["install-nvidia-driver" ] = "True"
372+ if launch_id is not None :
373+ metadata ["launch-id" ] = launch_id
330374
331375 args = [
332376 "gcloud" ,
@@ -346,7 +390,7 @@ def _gce_create_cmd(
346390 "--scopes=https://www.googleapis.com/auth/cloud-platform" ,
347391 (
348392 f"--create-disk=auto-delete=yes,boot=yes,device-name={ instance_name } ,"
349- f"image={ _IMAGE } ,mode=rw,size=200,type=pd-balanced "
393+ f"image={ _IMAGE } ,mode=rw,size=200,type={ config . boot_disk_type } "
350394 ),
351395 "--no-shielded-secure-boot" ,
352396 "--shielded-vtpm" ,
@@ -362,13 +406,69 @@ def _gce_create_cmd(
362406 return args
363407
364408
409+ def _gce_multi_flex_start (
410+ * ,
411+ config : GpuConfig ,
412+ base_instance_name : str ,
413+ num_zones : int ,
414+ path_to_metadata_script : dict [str , str ],
415+ ) -> None :
416+ """
417+ Fan out async FLEX_START submits across the first `num_zones` of `config.standard_zones`, all sharing the same
418+ `launch-id` metadata. First VM to boot claims a GCS atomic-create lock (see `bin/_startup.sh`); losers self-delete.
419+ """
420+ launch_id = uuid .uuid4 ().hex [:12 ]
421+ zones = config .standard_zones [:num_zones ]
422+ logger .info (f"Multi-flex-start launch-id={ launch_id } , fanning out to { len (zones )} zones: { zones } " )
423+ n_submitted = 0
424+ last_stockout_stderr = ""
425+ for zone in zones :
426+ gce_create_args = _gce_create_cmd (
427+ config ,
428+ base_instance_name , # instance names only need to be unique within (project, zone).
429+ zone ,
430+ provisioning_model = "FLEX_START" ,
431+ wait_for_instance_creation = False ,
432+ path_to_metadata_script = path_to_metadata_script ,
433+ launch_id = launch_id ,
434+ )
435+ result = subprocess .run (gce_create_args , capture_output = True , text = True )
436+ if result .returncode == 0 :
437+ logger .info (f"Flex-started { base_instance_name } in zone { zone } " )
438+ n_submitted += 1
439+ continue
440+
441+ if _is_stockout (result .stderr ): # I've gotten a stockout on flex-starts before
442+ logger .warning (f"Stockout in { zone } . Continuing with remaining zones" )
443+ last_stockout_stderr = result .stderr
444+ continue
445+
446+ # Non-stockout error: fail loudly. Any earlier siblings already submitted will still race for the lock.
447+ if n_submitted > 0 :
448+ logger .warning (
449+ f"{ n_submitted } sibling(s) already submitted will still race for the lock; "
450+ "you may still end up with a VM despite this error."
451+ )
452+ _raise_gce_create_failure (result , gce_create_args )
453+
454+ if n_submitted == 0 :
455+ suffix = f" Last stderr:\n { last_stockout_stderr } " if last_stockout_stderr else ""
456+ raise RuntimeError (f"Multi-flex-start: all { len (zones )} zones stocked out.{ suffix } " )
457+ logger .info (
458+ f"Multi-flex-start launched { n_submitted } /{ len (zones )} VMs (launch-id={ launch_id } ). "
459+ f"First to boot claims the winner lock, others self-delete."
460+ )
461+
462+
365463def gce_vm (
366464 * ,
367465 gpu : GpuType ,
368466 # TODO: support StrEnum-typed params in tap
369467 job_type : Literal [tuple (member .value for member in JobType )] = JobType .SSH , # type: ignore[valid-type]
370468 name_suffix : str = "" ,
371469 sync_start : bool = False ,
470+ multi_flex_start : bool = False ,
471+ multi_flex_start_num_zones : int = 10 ,
372472 command : str | None = None ,
373473 zone : str | None = None ,
374474 num_cycles_through_zones : int = 5 ,
@@ -379,7 +479,9 @@ def gce_vm(
379479 `_startup.sh` `eval`s it after env setup. If not, the instance just sets up the env and stops for when you want to
380480 SSH in and iterate.
381481
382- The instance's name is `grouping-trainer-{gpu}-{job_type}[-{name_suffix}]` (the suffix is dropped when empty).
482+ The instance's name is `grouping-trainer-{gpu}-{job_type}[-{name_suffix}]` (the suffix is dropped when empty). In
483+ `--multi_flex_start` mode, siblings share the same name — GCE instance names are unique per (project, zone), and
484+ the zone column of `gcloud compute instances list` already distinguishes them.
383485
384486 Parameters
385487 ----------
@@ -393,24 +495,35 @@ def gce_vm(
393495 sync_start
394496 If False (default), flex-starts the instance—GCP waits up to 1h to find one. `--sync_start` uses on-demand
395497 pricing and finds an instance in any zone, as flex-starting often can't find instances in time.
498+ multi_flex_start
499+ Fan out async FLEX_START submits across the first `multi_flex_start_num_zones` zones of `standard_zones`. First
500+ VM to boot wins via a GCS atomic-create lock; losers self-delete. Mutually exclusive with `sync_start`. Not
501+ applicable to GPU types without a `flex_start_zone` (e.g. L4).
502+ multi_flex_start_num_zones
503+ How many zones to fan out to in multi-flex mode. Capped at the length of the GPU config's `standard_zones`.
396504 command
397505 The command to run on the remote instance. If not given, the instance just sets up the env and stops for when
398506 you want to SSH in and iterate.
399507 zone
400508 Override the GCP zone. In FLEX_START mode pins the single submit zone; in STANDARD mode pins the loop to that
401- zone (still retried `num_cycles_through_zones` times).
509+ zone (still retried `num_cycles_through_zones` times). Ignored in `--multi_flex_start` mode.
402510 num_cycles_through_zones
403511 No-op for sync_start. Otherwise, loop through zones this many times before giving up.
404512 seconds_between_gce_create_attempts
405513 No-op for sync_start. Otherwise, sleep this many seconds b/t consecutive zone attempts.
406514 """
407515 config = gpu_type_to_config [gpu ]
516+ if multi_flex_start and sync_start :
517+ raise ValueError ("--multi_flex_start and --sync_start are mutually exclusive" )
518+ if multi_flex_start and config .flex_start_zone is None :
519+ raise ValueError (f"GPU type { gpu !r} does not support FLEX_START, so --multi_flex_start is not applicable" )
520+
408521 instance_name = f"{ _INSTANCE_NAME_PREFIX } -{ gpu } -{ job_type } "
409522 if name_suffix :
410523 # GCE instance names must match `[a-z]([-a-z0-9]*[a-z0-9])?`, so swap underscores for hyphens.
411524 instance_name += f"-{ name_suffix .replace ('_' , '-' )} "
412525
413- if (not sync_start ) and (config .flex_start_zone is not None ):
526+ if (not sync_start ) and (not multi_flex_start ) and ( config .flex_start_zone is not None ):
414527 provisioning_model : Literal ["FLEX_START" , "STANDARD" ] = "FLEX_START"
415528 wait_for_instance_creation = False
416529 zones_to_try : tuple [str , ...] = (zone or config .flex_start_zone ,)
@@ -431,6 +544,15 @@ def gce_vm(
431544 cmd_file .flush ()
432545 path_to_metadata_script ["command" ] = cmd_file .name
433546
547+ if multi_flex_start :
548+ _gce_multi_flex_start (
549+ config = config ,
550+ base_instance_name = instance_name ,
551+ num_zones = multi_flex_start_num_zones ,
552+ path_to_metadata_script = path_to_metadata_script ,
553+ )
554+ return
555+
434556 for zone_idx , zone in enumerate (zones_to_try ):
435557 # Attempt creation in this zone
436558 gce_create_args = _gce_create_cmd (
@@ -444,7 +566,8 @@ def gce_vm(
444566 logger .info (f"Attempting to create { instance_name } in zone { zone } " )
445567 gce_create_cmd_result = subprocess .run (gce_create_args , capture_output = True , text = True )
446568 if gce_create_cmd_result .returncode == 0 :
447- logger .info (f"Created { instance_name } in zone { zone } " )
569+ creation_type = "Flex-started" if provisioning_model == "FLEX_START" else "Created"
570+ logger .info (f"{ creation_type } { instance_name } in zone { zone } " )
448571 return
449572
450573 # Retry next zone if stockout
@@ -455,16 +578,9 @@ def gce_vm(
455578 continue
456579
457580 # Fail
458- logger .error (f"gcloud failed (exit { gce_create_cmd_result .returncode } ):\n { gce_create_cmd_result .stderr } " )
459581 if not sync_start :
460- logger .warning ("You may have success with --sync_start and no GCP zone pinning." )
461- # CalledProcessError's repr drops stderr, so log it
462- raise subprocess .CalledProcessError (
463- gce_create_cmd_result .returncode ,
464- gce_create_args ,
465- output = gce_create_cmd_result .stdout ,
466- stderr = gce_create_cmd_result .stderr ,
467- )
582+ logger .warning ("You may have success with --multi_flex_start if you're fine waiting, else --sync_start" )
583+ _raise_gce_create_failure (gce_create_cmd_result , gce_create_args )
468584
469585
470586def run_argv_remotely (
@@ -473,6 +589,7 @@ def run_argv_remotely(
473589 job_type : JobType ,
474590 name_suffix : str ,
475591 sync_start : bool = False ,
592+ multi_flex_start : bool = False ,
476593 zone : str | None = None ,
477594 env_var_to_value : dict [str , str ] | None = None ,
478595) -> None :
@@ -513,7 +630,7 @@ def run_argv_remotely(
513630 # Script path and args
514631 script_path = os .path .relpath (os .path .abspath (sys .argv [0 ]), _repo_root ())
515632 argv_remote = _strip_flags_and_their_values (argv = sys .argv [1 :], flags = ("--gpu" , "--zone" ))
516- argv_remote = _strip_bool_flags (argv = argv_remote , flags = ("--sync_start" ,))
633+ argv_remote = _strip_bool_flags (argv = argv_remote , flags = ("--sync_start" , "--multi_flex_start" ))
517634 args_for_remote = shlex .join (argv_remote )
518635
519636 command_parts .append (script_path )
@@ -527,10 +644,12 @@ def run_argv_remotely(
527644 job_type = job_type ,
528645 name_suffix = name_suffix ,
529646 sync_start = sync_start ,
647+ multi_flex_start = multi_flex_start ,
530648 command = command ,
531649 zone = zone ,
532650 )
533651
534652
if __name__ == "__main__":
    # Set up logging before tapify parses the CLI and invokes `gce_vm`, so its log lines use the launch config.
    configure_logging(process_type="launch")
    tapify(gce_vm)
0 commit comments