55import importlib
66import logging
77import os
8+ import urllib .parse
89from copy import deepcopy
9- from dataclasses import dataclass , replace
1010from collections .abc import Callable
11+ from dataclasses import dataclass , replace
1112from typing import TypeVar
1213
1314import draccus
2324)
2425from mergedeep import mergedeep
2526
26- from rigging .filesystem import check_gcs_paths_same_region , marin_temp_bucket
27+ from rigging .filesystem import check_gcs_paths_same_region , marin_temp_bucket , marin_temp_bucket_for_prefix
2728from marin .training .run_environment import add_run_env_variables
2829
2930logger = logging .getLogger (__name__ )
@@ -84,12 +85,34 @@ class TrainDpoOnPodConfig:
8485
8586DEFAULT_CHECKPOINTS_PATH = "checkpoints"
8687DEFAULT_HF_CHECKPOINTS_PATH = "hf"
88+ TEMPORARY_CHECKPOINT_TTL_DAYS = 14
89+ TEMPORARY_CHECKPOINTS_PATH = "checkpoints-temp"
8790
8891
8992def _cli_helpers_module ():
9093 return importlib .import_module ("levanter.infra.cli_helpers" )
9194
9295
96+ def _output_path_temp_component (output_path : str ) -> str :
97+ parsed = urllib .parse .urlparse (output_path )
98+ if parsed .scheme and parsed .netloc :
99+ return f"{ parsed .netloc } { parsed .path } " .strip ("/" )
100+ if parsed .scheme :
101+ return f"{ parsed .scheme } { parsed .path } " .strip ("/" )
102+ return output_path .strip ("/" )
103+
104+
105+ def temporary_checkpoint_base_path (output_path : str ) -> str :
106+ """Return the region-local temporary checkpoint base for an executor output path."""
107+ output_component = _output_path_temp_component (output_path )
108+ temp_prefix = os .path .join (TEMPORARY_CHECKPOINTS_PATH , output_component , DEFAULT_CHECKPOINTS_PATH )
109+ return marin_temp_bucket_for_prefix (
110+ ttl_days = TEMPORARY_CHECKPOINT_TTL_DAYS ,
111+ source_prefix = output_path ,
112+ prefix = temp_prefix ,
113+ )
114+
115+
93116def _update_config_to_use_out_path (pod_config : TrainOnPodConfigT ) -> TrainOnPodConfigT :
94117 """
95118 Update the config to use the out_path as the base output directory for training.
@@ -109,7 +132,7 @@ def _update_config_to_use_out_path(pod_config: TrainOnPodConfigT) -> TrainOnPodC
109132 checkpointer = replace (
110133 pod_config .train_config .trainer .checkpointer ,
111134 base_path = os .path .join (pod_config .output_path , DEFAULT_CHECKPOINTS_PATH ),
112- temporary_base_path = marin_temp_bucket ( ttl_days = 14 , prefix = "checkpoints-temp" ),
135+ temporary_base_path = temporary_checkpoint_base_path ( pod_config . output_path ),
113136 ),
114137 )
115138
0 commit comments