66import os
77from copy import deepcopy
88from dataclasses import dataclass , replace
9+ from collections .abc import Callable
910from typing import TypeVar
1011
1112import draccus
@@ -218,22 +219,13 @@ def _disable_xla_autotune_subcache(env: dict) -> None:
218219 logger .info ("XLA sub-caches disabled (compilation cache is remote: %s)" , cache_dir )
219220
220221
221- def run_levanter_train_lm (config : TrainLmOnPodConfig ):
222- """
223- Run the Levanter training main function on a Ray cluster.
224-
225- This function is designed to be run on your machine or with sufficient variables in the env dict/os env.
226- It should also be run with a Ray cluster already running.
227-
228- - WANDB_API_KEY: The API key for Weights and Biases.
229- - RUN_ID: (Optional) The run ID for this training run. Will default to a random UID if not set.
230- - GIT_COMMIT: (Optional) The git commit hash of the current codebase. Will attempt to fetch it if not set.
222+ def _prepare_training_run (
223+ config : TrainOnPodConfigT ,
224+ ) -> tuple [TrainOnPodConfigT , TrainLmConfig | TrainDpoConfig , dict [str , str ], list [str ]]:
225+ """Shared setup for LM and DPO training: env vars, run ID, config adjustments.
231226
232- This function makes a number of changes to the config and ensures a few things are set:
233- - The run ID is set, or sets a default if not.
234- - WANDB_API_KEY is set.
235- - It disables the auto-ray-start and auto-worker-start options since we're already in a Ray cluster.
236- - It checks that configured GCS paths are in the same region as the VM (except train/validation source URLs).
227+ Returns the updated pod config, the ready-to-use train config, the
228+ environment dict, and the Fray extras list.
237229 """
238230 default_launch_config = levanter .infra .cli_helpers .load_config ()
239231
@@ -245,7 +237,6 @@ def run_levanter_train_lm(config: TrainLmOnPodConfig):
245237 config .env_vars or {},
246238 default_launch_config .env_for_accel (config .resources .device .variant ),
247239 )
248- # if we're on tpu, ensure we have wandb
249240 if isinstance (config .resources .device , TpuConfig ):
250241 _check_for_wandb_key (env )
251242
@@ -261,16 +252,6 @@ def run_levanter_train_lm(config: TrainLmOnPodConfig):
261252 config = _enforce_run_id (config )
262253 logger .info (f"Using run ID: { config .train_config .trainer .id } " )
263254
264- model_config = config .train_config .model
265- logger .info (
266- "Model config: type=%s seq_len=%d hidden=%d batch=%s device=%s" ,
267- type (model_config ).__name__ ,
268- model_config .max_seq_len ,
269- model_config .Embed .size ,
270- config .train_config .trainer .train_batch_size ,
271- config .resources .device ,
272- )
273-
274255 train_config = config .train_config
275256 train_config = _suppress_ray_config (train_config )
276257 train_config = _maybe_override_auto_build_caches (train_config , config .auto_build_caches )
@@ -283,87 +264,93 @@ def run_levanter_train_lm(config: TrainLmOnPodConfig):
283264 if not isinstance (config .resources .device , CpuConfig ):
284265 _doublecheck_paths (config )
285266
286- client = current_client ()
287-
288- extras = []
267+ extras : list [str ] = []
289268 if isinstance (config .resources .device , TpuConfig ):
290269 extras .append ("tpu" )
291270 elif isinstance (config .resources .device , GpuConfig ):
292271 extras .append ("gpu" )
293272
294- # Note: Using a constant job name allows restarts to adopt the existing job handle
273+ return config , train_config , env , extras
274+
275+
276+ def _submit_training_job (
277+ * ,
278+ job_name : str ,
279+ main_fn : Callable ,
280+ train_config : TrainConfigT ,
281+ resources : ResourceConfig ,
282+ env : dict [str , str ],
283+ extras : list [str ],
284+ ) -> None :
285+ """Submit a Levanter training job to Fray and block until completion."""
286+ client = current_client ()
287+ # Using a constant job name allows restarts to adopt the existing job handle
295288 # instead of raising a duplicate name error (adopt_existing=True is the default).
296289 job_request = JobRequest (
297- name = "train_lm" ,
298- entrypoint = Entrypoint .from_callable (train_lm . main , args = [train_config ]),
299- resources = config . resources ,
290+ name = job_name ,
291+ entrypoint = Entrypoint .from_callable (main_fn , args = [train_config ]),
292+ resources = resources ,
300293 environment = create_environment (env_vars = env , extras = extras ),
301294 max_retries_failure = 10 ,
302295 )
303296 job = client .submit (job_request )
304297 job .wait (raise_on_failure = True )
305298
306299
307- def run_levanter_train_dpo (config : TrainDpoOnPodConfig ):
308- """
309- Run the Levanter DPO training main function on a Ray cluster.
300+ def run_levanter_train_lm (config : TrainLmOnPodConfig ):
301+ """Run the Levanter LM training main function on a Ray cluster.
310302
311303 This function is designed to be run on your machine or with sufficient variables in the env dict/os env.
312304 It should also be run with a Ray cluster already running.
313- """
314- default_launch_config = levanter .infra .cli_helpers .load_config ()
315305
316- if config .output_path is not None :
317- logger .info (f"Using output path: { config .output_path } " )
318- config = _update_config_to_use_out_path (config )
319-
320- env = _add_default_env_variables (
321- config .env_vars or {},
322- default_launch_config .env_for_accel (config .resources .device .variant ),
323- )
324- if isinstance (config .resources .device , TpuConfig ):
325- _check_for_wandb_key (env )
326-
327- env = _add_run_env_variables (env )
328-
329- if "JAX_COMPILATION_CACHE_DIR" not in env :
330- env ["JAX_COMPILATION_CACHE_DIR" ] = _normalize_jax_compilation_cache_dir (
331- marin_temp_bucket (ttl_days = 30 , prefix = "compilation-cache" )
332- )
333- logger .info ("JAX compilation cache: %s" , env ["JAX_COMPILATION_CACHE_DIR" ])
334- _disable_xla_autotune_subcache (env )
306+ - WANDB_API_KEY: The API key for Weights and Biases.
307+ - RUN_ID: (Optional) The run ID for this training run. Will default to a random UID if not set.
308+ - GIT_COMMIT: (Optional) The git commit hash of the current codebase. Will attempt to fetch it if not set.
335309
336- config = _enforce_run_id (config )
337- logger .info (f"Using run ID: { config .train_config .trainer .id } " )
310+ This function makes a number of changes to the config and ensures a few things are set:
311+ - The run ID is set, or sets a default if not.
312+ - WANDB_API_KEY is set.
313+ - It disables the auto-ray-start and auto-worker-start options since we're already in a Ray cluster.
314+ - It checks that configured GCS paths are in the same region as the VM (except train/validation source URLs).
315+ """
316+ config , train_config , env , extras = _prepare_training_run (config )
338317
339- train_config = config .train_config
340- train_config = _suppress_ray_config (train_config )
341- train_config = _maybe_override_auto_build_caches (train_config , config .auto_build_caches )
318+ model_config = train_config .model
319+ logger .info (
320+ "Model config: type=%s seq_len=%d hidden=%d batch=%s device=%s" ,
321+ type (model_config ).__name__ ,
322+ model_config .max_seq_len ,
323+ model_config .Embed .size ,
324+ train_config .trainer .train_batch_size ,
325+ config .resources .device ,
326+ )
342327
343- if config .resources .device .kind == "cpu" :
344- trainer = replace (train_config .trainer , require_accelerator = False )
345- train_config = replace (train_config , trainer = trainer )
328+ _submit_training_job (
329+ job_name = "train_lm" ,
330+ main_fn = train_lm .main ,
331+ train_config = train_config ,
332+ resources = config .resources ,
333+ env = env ,
334+ extras = extras ,
335+ )
346336
347- if not isinstance (config .resources .device , CpuConfig ):
348- _doublecheck_paths (config )
349337
350- client = current_client ()
338+ def run_levanter_train_dpo (config : TrainDpoOnPodConfig ):
339+ """Run the Levanter DPO training main function on a Ray cluster.
351340
352- extras = []
353- if isinstance (config .resources .device , TpuConfig ):
354- extras .append ("tpu" )
355- elif isinstance (config .resources .device , GpuConfig ):
356- extras .append ("gpu" )
341+ This function is designed to be run on your machine or with sufficient variables in the env dict/os env.
342+ It should also be run with a Ray cluster already running.
343+ """
344+ config , train_config , env , extras = _prepare_training_run (config )
357345
358- job_request = JobRequest (
359- name = "train_dpo" ,
360- entrypoint = Entrypoint .from_callable (train_dpo .main , args = [train_config ]),
346+ _submit_training_job (
347+ job_name = "train_dpo" ,
348+ main_fn = train_dpo .main ,
349+ train_config = train_config ,
361350 resources = config .resources ,
362- environment = create_environment ( env_vars = env , extras = extras ) ,
363- max_retries_failure = 10 ,
351+ env = env ,
352+ extras = extras ,
364353 )
365- job = client .submit (job_request )
366- job .wait (raise_on_failure = True )
367354
368355
369356def _doublecheck_paths (config : TrainOnPodConfigT ):
0 commit comments