Skip to content

Commit 34f2e2d

Browse files
github-actions[bot], Nightshift Agent, and claude
authored
[nightshift] Deduplicate LM/DPO training setup in marin (#4255)
> *Forked paths converge—*
> *shared roots beneath the diff*
> *one function to bind*

## Summary

- Extracted `_prepare_training_run()` and `_submit_training_job()` from `run_levanter_train_lm` and `run_levanter_train_dpo`, which were ~50-line near-identical copies of each other
- Both public functions now delegate to the shared helpers, keeping only their unique logic (LM logs model config details; DPO does not)
- Net reduction of ~13 lines and elimination of a maintenance hazard where fixes to one path could easily be missed in the other

## Test plan

- [x] `uv run --package marin pytest tests/test_training.py -x` — 4/4 pass
- [x] `./infra/pre-commit.py --all-files --fix` — clean
- [ ] Manual: verify experiment pipelines that call `run_levanter_train_lm` / `run_levanter_train_dpo` still work end-to-end

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Nightshift Agent <nightshift-agent@marin-community.github.io>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 2dcc1b1 commit 34f2e2d

File tree

1 file changed

+67
-80
lines changed

1 file changed

+67
-80
lines changed

lib/marin/src/marin/training/training.py

Lines changed: 67 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
from copy import deepcopy
88
from dataclasses import dataclass, replace
9+
from collections.abc import Callable
910
from typing import TypeVar
1011

1112
import draccus
@@ -218,22 +219,13 @@ def _disable_xla_autotune_subcache(env: dict) -> None:
218219
logger.info("XLA sub-caches disabled (compilation cache is remote: %s)", cache_dir)
219220

220221

221-
def run_levanter_train_lm(config: TrainLmOnPodConfig):
222-
"""
223-
Run the Levanter training main function on a Ray cluster.
224-
225-
This function is designed to be run on your machine or with sufficient variables in the env dict/os env.
226-
It should also be run with a Ray cluster already running.
227-
228-
- WANDB_API_KEY: The API key for Weights and Biases.
229-
- RUN_ID: (Optional) The run ID for this training run. Will default to a random UID if not set.
230-
- GIT_COMMIT: (Optional) The git commit hash of the current codebase. Will attempt to fetch it if not set.
222+
def _prepare_training_run(
223+
config: TrainOnPodConfigT,
224+
) -> tuple[TrainOnPodConfigT, TrainLmConfig | TrainDpoConfig, dict[str, str], list[str]]:
225+
"""Shared setup for LM and DPO training: env vars, run ID, config adjustments.
231226
232-
This function makes a number of changes to the config and ensures a few things are set:
233-
- The run ID is set, or sets a default if not.
234-
- WANDB_API_KEY is set.
235-
- It disables the auto-ray-start and auto-worker-start options since we're already in a Ray cluster.
236-
- It checks that configured GCS paths are in the same region as the VM (except train/validation source URLs).
227+
Returns the updated pod config, the ready-to-use train config, the
228+
environment dict, and the Fray extras list.
237229
"""
238230
default_launch_config = levanter.infra.cli_helpers.load_config()
239231

@@ -245,7 +237,6 @@ def run_levanter_train_lm(config: TrainLmOnPodConfig):
245237
config.env_vars or {},
246238
default_launch_config.env_for_accel(config.resources.device.variant),
247239
)
248-
# if we're on tpu, ensure we have wandb
249240
if isinstance(config.resources.device, TpuConfig):
250241
_check_for_wandb_key(env)
251242

@@ -261,16 +252,6 @@ def run_levanter_train_lm(config: TrainLmOnPodConfig):
261252
config = _enforce_run_id(config)
262253
logger.info(f"Using run ID: {config.train_config.trainer.id}")
263254

264-
model_config = config.train_config.model
265-
logger.info(
266-
"Model config: type=%s seq_len=%d hidden=%d batch=%s device=%s",
267-
type(model_config).__name__,
268-
model_config.max_seq_len,
269-
model_config.Embed.size,
270-
config.train_config.trainer.train_batch_size,
271-
config.resources.device,
272-
)
273-
274255
train_config = config.train_config
275256
train_config = _suppress_ray_config(train_config)
276257
train_config = _maybe_override_auto_build_caches(train_config, config.auto_build_caches)
@@ -283,87 +264,93 @@ def run_levanter_train_lm(config: TrainLmOnPodConfig):
283264
if not isinstance(config.resources.device, CpuConfig):
284265
_doublecheck_paths(config)
285266

286-
client = current_client()
287-
288-
extras = []
267+
extras: list[str] = []
289268
if isinstance(config.resources.device, TpuConfig):
290269
extras.append("tpu")
291270
elif isinstance(config.resources.device, GpuConfig):
292271
extras.append("gpu")
293272

294-
# Note: Using a constant job name allows restarts to adopt the existing job handle
273+
return config, train_config, env, extras
274+
275+
276+
def _submit_training_job(
277+
*,
278+
job_name: str,
279+
main_fn: Callable,
280+
train_config: TrainConfigT,
281+
resources: ResourceConfig,
282+
env: dict[str, str],
283+
extras: list[str],
284+
) -> None:
285+
"""Submit a Levanter training job to Fray and block until completion."""
286+
client = current_client()
287+
# Using a constant job name allows restarts to adopt the existing job handle
295288
# instead of raising a duplicate name error (adopt_existing=True is the default).
296289
job_request = JobRequest(
297-
name="train_lm",
298-
entrypoint=Entrypoint.from_callable(train_lm.main, args=[train_config]),
299-
resources=config.resources,
290+
name=job_name,
291+
entrypoint=Entrypoint.from_callable(main_fn, args=[train_config]),
292+
resources=resources,
300293
environment=create_environment(env_vars=env, extras=extras),
301294
max_retries_failure=10,
302295
)
303296
job = client.submit(job_request)
304297
job.wait(raise_on_failure=True)
305298

306299

307-
def run_levanter_train_dpo(config: TrainDpoOnPodConfig):
308-
"""
309-
Run the Levanter DPO training main function on a Ray cluster.
300+
def run_levanter_train_lm(config: TrainLmOnPodConfig):
301+
"""Run the Levanter LM training main function on a Ray cluster.
310302
311303
This function is designed to be run on your machine or with sufficient variables in the env dict/os env.
312304
It should also be run with a Ray cluster already running.
313-
"""
314-
default_launch_config = levanter.infra.cli_helpers.load_config()
315305
316-
if config.output_path is not None:
317-
logger.info(f"Using output path: {config.output_path}")
318-
config = _update_config_to_use_out_path(config)
319-
320-
env = _add_default_env_variables(
321-
config.env_vars or {},
322-
default_launch_config.env_for_accel(config.resources.device.variant),
323-
)
324-
if isinstance(config.resources.device, TpuConfig):
325-
_check_for_wandb_key(env)
326-
327-
env = _add_run_env_variables(env)
328-
329-
if "JAX_COMPILATION_CACHE_DIR" not in env:
330-
env["JAX_COMPILATION_CACHE_DIR"] = _normalize_jax_compilation_cache_dir(
331-
marin_temp_bucket(ttl_days=30, prefix="compilation-cache")
332-
)
333-
logger.info("JAX compilation cache: %s", env["JAX_COMPILATION_CACHE_DIR"])
334-
_disable_xla_autotune_subcache(env)
306+
- WANDB_API_KEY: The API key for Weights and Biases.
307+
- RUN_ID: (Optional) The run ID for this training run. Will default to a random UID if not set.
308+
- GIT_COMMIT: (Optional) The git commit hash of the current codebase. Will attempt to fetch it if not set.
335309
336-
config = _enforce_run_id(config)
337-
logger.info(f"Using run ID: {config.train_config.trainer.id}")
310+
This function makes a number of changes to the config and ensures a few things are set:
311+
- The run ID is set, or sets a default if not.
312+
- WANDB_API_KEY is set.
313+
- It disables the auto-ray-start and auto-worker-start options since we're already in a Ray cluster.
314+
- It checks that configured GCS paths are in the same region as the VM (except train/validation source URLs).
315+
"""
316+
config, train_config, env, extras = _prepare_training_run(config)
338317

339-
train_config = config.train_config
340-
train_config = _suppress_ray_config(train_config)
341-
train_config = _maybe_override_auto_build_caches(train_config, config.auto_build_caches)
318+
model_config = train_config.model
319+
logger.info(
320+
"Model config: type=%s seq_len=%d hidden=%d batch=%s device=%s",
321+
type(model_config).__name__,
322+
model_config.max_seq_len,
323+
model_config.Embed.size,
324+
train_config.trainer.train_batch_size,
325+
config.resources.device,
326+
)
342327

343-
if config.resources.device.kind == "cpu":
344-
trainer = replace(train_config.trainer, require_accelerator=False)
345-
train_config = replace(train_config, trainer=trainer)
328+
_submit_training_job(
329+
job_name="train_lm",
330+
main_fn=train_lm.main,
331+
train_config=train_config,
332+
resources=config.resources,
333+
env=env,
334+
extras=extras,
335+
)
346336

347-
if not isinstance(config.resources.device, CpuConfig):
348-
_doublecheck_paths(config)
349337

350-
client = current_client()
338+
def run_levanter_train_dpo(config: TrainDpoOnPodConfig):
339+
"""Run the Levanter DPO training main function on a Ray cluster.
351340
352-
extras = []
353-
if isinstance(config.resources.device, TpuConfig):
354-
extras.append("tpu")
355-
elif isinstance(config.resources.device, GpuConfig):
356-
extras.append("gpu")
341+
This function is designed to be run on your machine or with sufficient variables in the env dict/os env.
342+
It should also be run with a Ray cluster already running.
343+
"""
344+
config, train_config, env, extras = _prepare_training_run(config)
357345

358-
job_request = JobRequest(
359-
name="train_dpo",
360-
entrypoint=Entrypoint.from_callable(train_dpo.main, args=[train_config]),
346+
_submit_training_job(
347+
job_name="train_dpo",
348+
main_fn=train_dpo.main,
349+
train_config=train_config,
361350
resources=config.resources,
362-
environment=create_environment(env_vars=env, extras=extras),
363-
max_retries_failure=10,
351+
env=env,
352+
extras=extras,
364353
)
365-
job = client.submit(job_request)
366-
job.wait(raise_on_failure=True)
367354

368355

369356
def _doublecheck_paths(config: TrainOnPodConfigT):

0 commit comments

Comments (0)