Skip to content

Commit 6014f9e

Browse files
committed
change zones
1 parent 6d23bc3 commit 6014f9e

3 files changed

Lines changed: 18 additions & 16 deletions

File tree

bin/_startup.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
# after SSH'ing in, run `sudo -i` first so $HOME=/root and paths line up.
55
set -euo pipefail
66

7-
# Install uv (manages its own Python; respects .python-version in the repo).
7+
# GCP's metadata script runner doesn't export HOME, so fall back to /root when it is unset.
8+
export HOME="${HOME:-/root}"
9+
10+
# Install uv (manages its own Python, respects .python-version in the repo).
811
curl -LsSf https://astral.sh/uv/install.sh | sh
912
export PATH="$HOME/.local/bin:$PATH"
1013

src/grouping_trainer/launch.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,14 @@ class GpuConfig(BaseModel):
4141
model_config = ConfigDict(frozen=True, extra="forbid")
4242

4343
name: str
44-
zone: str # default zone; flex-start capacity varies across regions, so this
45-
# gets overridden via the --zone flag when the default is dry.
44+
zone: str = "us-central1-a"
4645
machine_type: str
47-
accelerator: str | None # None for *-ddp variants — accelerators are built
48-
# into the machine type, so passing --accelerator is redundant/erroneous.
46+
accelerator: str | None # None for *-ddp variants because accelerators are built into the machine type
4947
max_run: str
5048
install_nvidia_driver: bool
5149
reservation_affinity: Literal["none", "any"]
52-
wait: bool # whether to block locally on instance creation. False adds --async.
50+
wait_for_instance_creation: bool
51+
is_for_training: bool
5352

5453

5554
gpu_type_to_config: dict[GpuType, GpuConfig] = {
@@ -61,7 +60,7 @@ class GpuConfig(BaseModel):
6160
max_run="86400s",
6261
install_nvidia_driver=False,
6362
reservation_affinity="any",
64-
wait=True, # L4s come up fast; block so errors surface promptly
63+
wait_for_instance_creation=True, # L4s come up fast. Block so errors surface promptly
6564
),
6665
"h100": GpuConfig(
6766
name="grouping-trainer-h100",
@@ -71,7 +70,7 @@ class GpuConfig(BaseModel):
7170
max_run="86400s",
7271
install_nvidia_driver=True,
7372
reservation_affinity="none",
74-
wait=False, # flex-start can queue for up to 1h; don't block the shell
73+
wait_for_instance_creation=False, # flex-start can queue for up to 1h, so don't wait on creation
7574
),
7675
"h100-ddp": GpuConfig(
7776
name="grouping-trainer-h100-ddp",
@@ -81,7 +80,7 @@ class GpuConfig(BaseModel):
8180
max_run="172800s",
8281
install_nvidia_driver=True,
8382
reservation_affinity="none",
84-
wait=False,
83+
wait_for_instance_creation=False,
8584
),
8685
"a100": GpuConfig(
8786
name="grouping-trainer-a100",
@@ -91,7 +90,7 @@ class GpuConfig(BaseModel):
9190
max_run="86400s",
9291
install_nvidia_driver=True,
9392
reservation_affinity="none",
94-
wait=False,
93+
wait_for_instance_creation=False,
9594
),
9695
"a100-ddp": GpuConfig(
9796
name="grouping-trainer-a100-ddp",
@@ -101,7 +100,7 @@ class GpuConfig(BaseModel):
101100
max_run="172800s",
102101
install_nvidia_driver=True,
103102
reservation_affinity="none",
104-
wait=False,
103+
wait_for_instance_creation=False,
105104
),
106105
}
107106

train.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def upload_run_metadata(run_gcs_dir: str, training_config: gt.train.TrainingConf
5151
"jinaai/jina-embeddings-v5-text-nano-text-matching": 4,
5252
}
5353

54+
TrainingGpuType = Literal[tuple(gpu_type for gpu_type in gt.launch.gpu_type_to_config.keys() if gpu_type != "l4")]
55+
5456

5557
def run(
5658
base_model: str = "lightonai/modernbert-embed-large",
@@ -60,7 +62,7 @@ def run(
6062
per_device_train_batch_size: int = 256,
6163
learning_rate: float = 1e-4,
6264
tiny_run: bool = False,
63-
gpu: Literal["h100", "h100-ddp", "a100", "a100-ddp"] | None = None,
65+
gpu: TrainingGpuType | None = None,
6466
zone: str | None = None,
6567
):
6668
"""
@@ -95,10 +97,8 @@ def run(
9597
if not tiny_run:
9698
assert run_shortname is not None, "run_shortname is required for full training runs"
9799

98-
# Generate run_name up front so we can log the artifact URL locally before
99-
# auto-launching. On the remote, re-use the local run_name via env var so
100-
# both sides log the same GCS path (rather than each generating its own
101-
# timestamp).
100+
# Generate run_name up front so we can log the artifact URL locally before auto-launching. On the remote, re-use the
101+
# local run_name via env var so both sides log the same GCS path (rather than each generating its own timestamp).
102102
run_name_env = os.environ.get(_RUN_NAME_ENV_VAR)
103103
if run_name_env:
104104
run_name = run_name_env

0 commit comments

Comments
 (0)