Skip to content

Commit 990af90

Browse files
committed
Wire temporary checkpoint roots into Marin launches
1 parent c05d339 commit 990af90

8 files changed

Lines changed: 151 additions & 9 deletions

File tree

experiments/grug/base/launch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from levanter.trainer import TrainerConfig
2424
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
2525
from marin.processing.tokenize import add_validation_sets_to_mixture
26+
from marin.training.training import temporary_checkpoint_base_path
2627

2728
from experiments.defaults import default_validation_sets
2829
from experiments.grug.base.model import GrugModelConfig
@@ -99,6 +100,7 @@ def run_grug_base_trial(config: GrugBaseLaunchConfig) -> None:
99100
allow_nondivisible_batch_size=False,
100101
checkpointer=CheckpointerConfig(
101102
base_path=os.path.join(config.output_path, "checkpoints"),
103+
temporary_base_path=temporary_checkpoint_base_path(config.output_path),
102104
append_run_id_to_base_path=False,
103105
save_interval=timedelta(minutes=10),
104106
keep=[{"every": 1000}],

experiments/grug/modular_opt/launch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from levanter.utils.jax_utils import leaf_key_paths
2727
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
2828
from marin.processing.tokenize import add_validation_sets_to_mixture
29+
from marin.training.training import temporary_checkpoint_base_path
2930

3031
from experiments.defaults import default_validation_sets
3132
from experiments.grug.modular_opt.model import GrugModelConfig
@@ -205,6 +206,7 @@ def run_grug_modular_opt_trial(config: GrugModularOptLaunchConfig) -> None:
205206
allow_nondivisible_batch_size=False,
206207
checkpointer=CheckpointerConfig(
207208
base_path=os.path.join(config.output_path, "checkpoints"),
209+
temporary_base_path=temporary_checkpoint_base_path(config.output_path),
208210
append_run_id_to_base_path=False,
209211
save_interval=timedelta(minutes=10),
210212
keep=[{"every": 1000}],

experiments/grug/moe/launch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from levanter.utils.mesh import MeshConfig
2525
from marin.execution.executor import ExecutorStep, executor_main, this_output_path, versioned
2626
from marin.processing.tokenize import add_validation_sets_to_mixture
27+
from marin.training.training import temporary_checkpoint_base_path
2728

2829
from experiments.defaults import default_validation_sets
2930
from experiments.grug.moe.heuristic import build_from_heuristic
@@ -92,6 +93,7 @@ def run_grug_moe_trial(config: GrugMoeLaunchConfig) -> None:
9293
allow_nondivisible_batch_size=False,
9394
checkpointer=CheckpointerConfig(
9495
base_path=os.path.join(config.output_path, "checkpoints"),
96+
temporary_base_path=temporary_checkpoint_base_path(config.output_path),
9597
append_run_id_to_base_path=False,
9698
save_interval=timedelta(minutes=10),
9799
keep=[{"every": 1000}],

lib/iris/tests/test_marin_fs.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
marin_prefix,
2222
marin_region,
2323
marin_temp_bucket,
24+
marin_temp_bucket_for_prefix,
2425
open_url,
2526
region_from_metadata,
2627
region_from_prefix,
@@ -136,6 +137,18 @@ def test_marin_temp_bucket_from_env_prefix():
136137
assert marin_temp_bucket(ttl_days=3, prefix="zephyr") == "gs://marin-tmp-us-east1/ttl=3d/zephyr"
137138

138139

140+
def test_marin_temp_bucket_for_prefix_uses_source_region():
    """The temp bucket follows the region of ``source_prefix`` rather than ``MARIN_PREFIX``."""
    # The metadata lookup fails and MARIN_PREFIX points at us-central1, so an
    # us-east5 result can only have come from the source prefix.
    metadata_unreachable = patch("rigging.filesystem.urllib.request.urlopen", side_effect=OSError("not on GCP"))
    central_prefix_env = patch.dict(os.environ, {"MARIN_PREFIX": "gs://marin-us-central1/scratch"})
    expected = "gs://marin-tmp-us-east5/ttl=14d/checkpoints-temp/marin-us-east5/experiments/grug/run/checkpoints"

    with metadata_unreachable, central_prefix_env:
        resolved = marin_temp_bucket_for_prefix(
            ttl_days=14,
            source_prefix="gs://marin-us-east5/experiments/grug/run",
            prefix="checkpoints-temp/marin-us-east5/experiments/grug/run/checkpoints",
        )
        assert resolved == expected
150+
151+
139152
def test_marin_temp_bucket_falls_back_to_marin_prefix_when_no_region():
140153
# Unknown region in MARIN_PREFIX → no entry in REGION_TO_TMP_BUCKET → falls back to marin_prefix/tmp
141154
with (

lib/marin/src/marin/training/training.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,10 @@
55
import importlib
66
import logging
77
import os
8+
import urllib.parse
89
from copy import deepcopy
9-
from dataclasses import dataclass, replace
1010
from collections.abc import Callable
11+
from dataclasses import dataclass, replace
1112
from typing import TypeVar
1213

1314
import draccus
@@ -23,7 +24,7 @@
2324
)
2425
from mergedeep import mergedeep
2526

26-
from rigging.filesystem import check_gcs_paths_same_region, marin_temp_bucket
27+
from rigging.filesystem import check_gcs_paths_same_region, marin_temp_bucket, marin_temp_bucket_for_prefix
2728
from marin.training.run_environment import add_run_env_variables
2829

2930
logger = logging.getLogger(__name__)
@@ -84,12 +85,34 @@ class TrainDpoOnPodConfig:
8485

8586
DEFAULT_CHECKPOINTS_PATH = "checkpoints"
8687
DEFAULT_HF_CHECKPOINTS_PATH = "hf"
88+
TEMPORARY_CHECKPOINT_TTL_DAYS = 14
89+
TEMPORARY_CHECKPOINTS_PATH = "checkpoints-temp"
8790

8891

8992
def _cli_helpers_module():
9093
return importlib.import_module("levanter.infra.cli_helpers")
9194

9295

96+
def _output_path_temp_component(output_path: str) -> str:
97+
parsed = urllib.parse.urlparse(output_path)
98+
if parsed.scheme and parsed.netloc:
99+
return f"{parsed.netloc}{parsed.path}".strip("/")
100+
if parsed.scheme:
101+
return f"{parsed.scheme}{parsed.path}".strip("/")
102+
return output_path.strip("/")
103+
104+
105+
def temporary_checkpoint_base_path(output_path: str) -> str:
    """Return the region-local temporary checkpoint base for an executor output path.

    The temp prefix is ``TEMPORARY_CHECKPOINTS_PATH/<flattened output path>/
    DEFAULT_CHECKPOINTS_PATH``, so each output path gets its own temp
    directory. ``output_path`` is also passed as the source prefix from which
    ``marin_temp_bucket_for_prefix`` picks the temp bucket.
    """
    run_component = _output_path_temp_component(output_path)
    return marin_temp_bucket_for_prefix(
        ttl_days=TEMPORARY_CHECKPOINT_TTL_DAYS,
        source_prefix=output_path,
        prefix=os.path.join(TEMPORARY_CHECKPOINTS_PATH, run_component, DEFAULT_CHECKPOINTS_PATH),
    )
114+
115+
93116
def _update_config_to_use_out_path(pod_config: TrainOnPodConfigT) -> TrainOnPodConfigT:
94117
"""
95118
Update the config to use the out_path as the base output directory for training.
@@ -109,7 +132,7 @@ def _update_config_to_use_out_path(pod_config: TrainOnPodConfigT) -> TrainOnPodC
109132
checkpointer=replace(
110133
pod_config.train_config.trainer.checkpointer,
111134
base_path=os.path.join(pod_config.output_path, DEFAULT_CHECKPOINTS_PATH),
112-
temporary_base_path=marin_temp_bucket(ttl_days=14, prefix="checkpoints-temp"),
135+
temporary_base_path=temporary_checkpoint_base_path(pod_config.output_path),
113136
),
114137
)
115138

lib/rigging/src/rigging/filesystem.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,12 @@ def marin_region() -> str | None:
157157
return region_from_metadata() or region_from_prefix(os.environ.get("MARIN_PREFIX", ""))
158158

159159

160+
def _append_path_prefix(path: str, prefix: str) -> str:
161+
if prefix:
162+
return f"{path}/{prefix.strip('/')}"
163+
return path
164+
165+
160166
def marin_temp_bucket(ttl_days: int, prefix: str = "") -> str:
161167
"""Return a path on region-local temp storage. Never returns ``None``.
162168
@@ -186,16 +192,26 @@ def marin_temp_bucket(ttl_days: int, prefix: str = "") -> str:
186192
bucket = REGION_TO_TMP_BUCKET.get(region)
187193
if bucket:
188194
path = f"gs://{bucket}/ttl={ttl_days}d"
189-
if prefix:
190-
path = f"{path}/{prefix.strip('/')}"
191-
return path
195+
return _append_path_prefix(path, prefix)
192196

193197
if "://" not in mp:
194198
mp = f"file://{mp}"
195199
path = f"{mp}/tmp"
196-
if prefix:
197-
path = f"{path}/{prefix.strip('/')}"
198-
return path
200+
return _append_path_prefix(path, prefix)
201+
202+
203+
def marin_temp_bucket_for_prefix(ttl_days: int, source_prefix: str, prefix: str = "") -> str:
    """Return temp storage in the same region as ``source_prefix`` when possible.

    This is useful when configuring a remote job from a launcher that may be in
    a different region than the job output path.

    If no region can be derived from ``source_prefix``, or that region has no
    entry in ``REGION_TO_TMP_BUCKET``, the result is whatever
    ``marin_temp_bucket`` returns for the same ``ttl_days`` and ``prefix``.
    """
    source_region = region_from_prefix(source_prefix)
    tmp_bucket = REGION_TO_TMP_BUCKET.get(source_region) if source_region else None
    if not tmp_bucket:
        return marin_temp_bucket(ttl_days, prefix=prefix)
    return _append_path_prefix(f"gs://{tmp_bucket}/ttl={ttl_days}d", prefix)
199215

200216

201217
# ---------------------------------------------------------------------------
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import os
5+
from typing import Any
6+
from unittest.mock import patch
7+
8+
from fray.cluster import ResourceConfig
9+
from levanter.optim import AdamConfig
10+
from levanter.tracker import NoopConfig
11+
12+
from experiments.grug.base.launch import GRUG_130M_MODEL, GrugBaseLaunchConfig, run_grug_base_trial
13+
14+
_DUMMY_DATA: Any = object()
15+
16+
17+
def test_grug_base_launch_sets_temporary_checkpoint_base_path():
    """Launching a grug base trial configures both the permanent and temporary checkpoint roots."""
    expected_base = "gs://marin-us-east5/experiments/grug/base-trial/checkpoints"
    expected_temp = (
        "gs://marin-tmp-us-east5/ttl=14d/"
        "checkpoints-temp/marin-us-east5/experiments/grug/base-trial/checkpoints"
    )

    with (
        patch("rigging.filesystem.urllib.request.urlopen", side_effect=OSError("not on GCP")),
        patch.dict(os.environ, {"MARIN_PREFIX": "gs://marin-us-central1/scratch"}),
        # Replace the real training entry point so only the config it receives is inspected.
        patch("experiments.grug.base.launch.run_grug") as launched,
    ):
        launch_config = GrugBaseLaunchConfig(
            model=GRUG_130M_MODEL,
            data=_DUMMY_DATA,
            output_path="gs://marin-us-east5/experiments/grug/base-trial",
            run_id="grug-temp-path-test",
            resources=ResourceConfig.with_cpu(),
            steps=1,
            batch_size=1,
            seed=0,
            mp="params=float32,compute=bfloat16,output=bfloat16",
            tracker=NoopConfig(),
            optimizer=AdamConfig(),
            eval=None,
        )
        run_grug_base_trial(launch_config)

    trainer = launched.call_args.args[0].trainer.trainer
    assert trainer.checkpointer.base_path == expected_base
    assert trainer.checkpointer.temporary_base_path == expected_temp
    assert trainer.checkpoint_search_paths("grug-temp-path-test") == [expected_base, expected_temp]

tests/test_training.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: Apache-2.0
33

44
import dataclasses
5+
import os
56
from pathlib import Path
67
from unittest.mock import patch
78

@@ -14,6 +15,8 @@
1415
from marin.training.training import (
1516
TrainLmOnPodConfig,
1617
_doublecheck_paths,
18+
_update_config_to_use_out_path,
19+
temporary_checkpoint_base_path,
1720
)
1821

1922

@@ -64,6 +67,38 @@ def test_lm_config_with_train_urls_allowed_out_of_region(trainer_config):
6467
_doublecheck_paths(config)
6568

6669

70+
def test_temporary_checkpoint_base_path_follows_output_path_region():
    """The temp checkpoint base uses the output path's region, not the one in ``MARIN_PREFIX``."""
    metadata_unreachable = patch("rigging.filesystem.urllib.request.urlopen", side_effect=OSError("not on GCP"))
    central_prefix_env = patch.dict(os.environ, {"MARIN_PREFIX": "gs://marin-us-central1/scratch"})
    expected = (
        "gs://marin-tmp-us-east5/ttl=14d/"
        "checkpoints-temp/marin-us-east5/experiments/grug/base-trial/checkpoints"
    )

    with metadata_unreachable, central_prefix_env:
        resolved = temporary_checkpoint_base_path("gs://marin-us-east5/experiments/grug/base-trial")
        assert resolved == expected
78+
79+
80+
def test_update_config_to_use_out_path_sets_run_specific_temp_checkpoints(trainer_config):
    """Rewriting a pod config for its output path sets both checkpoint roots per run."""
    output_path = "gs://marin-us-east5/experiments/grug/base-trial"

    with (
        patch("rigging.filesystem.urllib.request.urlopen", side_effect=OSError("not on GCP")),
        patch.dict(os.environ, {"MARIN_PREFIX": "gs://marin-us-central1/scratch"}),
    ):
        pod_config = TrainLmOnPodConfig(
            train_config=train_lm.TrainLmConfig(trainer=trainer_config),
            resources=ResourceConfig.with_tpu("v4-8"),
            output_path=output_path,
        )
        rewritten = _update_config_to_use_out_path(pod_config)

    # The assertions only read attributes from the already-built config.
    ckpt = rewritten.train_config.trainer.checkpointer
    assert ckpt.base_path == "gs://marin-us-east5/experiments/grug/base-trial/checkpoints"
    assert ckpt.temporary_base_path == (
        "gs://marin-tmp-us-east5/ttl=14d/"
        "checkpoints-temp/marin-us-east5/experiments/grug/base-trial/checkpoints"
    )
100+
101+
67102
def test_recursive_path_checking(trainer_config):
68103
"""Paths are checked recursively in nested structures."""
69104
with (

0 commit comments

Comments
 (0)