Commit 2460770

marin: port export/levanter_checkpoint.py to fray v2 (#4980)
🤖

## Summary

Ports `lib/marin/src/marin/export/levanter_checkpoint.py` from `fray.v1.cluster.*` to `fray.v2`, mirroring the pattern in `lib/marin/src/marin/evaluation/log_probs.py` (already on v2). The public API (`ConvertCheckpointStepConfig`, `convert_checkpoint_to_hf`, `convert_checkpoint_to_hf_step`) is unchanged.

This is Bucket 2 PR-B in the Ray-removal roadmap. Part of [#4453](#4453).

## Why medium-risk

This file backs the Marin HF model-release pipeline (`experiments/tootsie/exp_1246_upload_datasets.py` and `experiments/tootsie/exp1984_convert_32b_phases.py`). The port preserves every public-facing signature; only the job-submission internals change.

## Test plan

- [x] `./infra/pre-commit.py --all-files --fix`: green.
- [x] `uv run pyrefly`: no new errors vs `origin/main` baseline.
- [x] Import check: `marin.export` + both tootsie experiments import cleanly.
- [x] Any `marin.export`-importing tests pass.
- [x] No `fray.v1` references remain under `lib/marin/src/marin/export/`.
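The last test-plan item can be re-checked locally with a short script. The PR does not say how the check was performed, so the following is only a minimal sketch: the helper name `find_references` is hypothetical, and it assumes a plain substring match over `*.py` files is sufficient.

```python
from pathlib import Path


def find_references(root: Path, needle: str = "fray.v1") -> list[tuple[Path, int, str]]:
    """Return (file, line number, stripped line) for each line under root containing needle."""
    hits: list[tuple[Path, int, str]] = []
    for path in sorted(root.rglob("*.py")):
        text = path.read_text(encoding="utf-8")
        for lineno, line in enumerate(text.splitlines(), start=1):
            if needle in line:
                hits.append((path, lineno, line.strip()))
    return hits


if __name__ == "__main__":
    # Directory taken from the test plan; run from the repository root.
    export_dir = Path("lib/marin/src/marin/export")
    for path, lineno, line in find_references(export_dir):
        print(f"{path}:{lineno}: {line}")
```

An empty result is consistent with the checklist claim. A substring match would also flag comments or strings that mention `fray.v1`, which is probably the desired behavior for a migration check.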
1 parent a2f4a1a commit 2460770

1 file changed

Lines changed: 14 additions & 8 deletions

lib/marin/src/marin/export/levanter_checkpoint.py

@@ -7,14 +7,15 @@
 from typing import Any
 
 import levanter.infra.cli_helpers
-from fray.v1.cluster import (
+from fray.v2 import current_client
+from fray.v2.types import (
     CpuConfig,
     Entrypoint,
-    EnvironmentConfig,
+    GpuConfig,
     JobRequest,
     ResourceConfig,
     TpuConfig,
-    current_cluster,
+    create_environment,
 )
 from levanter.checkpoint import discover_latest_checkpoint
 from levanter.compat.hf_checkpoints import RepoRef
@@ -109,16 +110,21 @@ def _run_with_lockfile():
     if isinstance(config.resources.device, TpuConfig):
         assert config.resources.replicas == 1, "Export currently works on single slices at present."
 
+    extras: list[str] = []
+    if isinstance(config.resources.device, TpuConfig):
+        extras.append("tpu")
+    elif isinstance(config.resources.device, GpuConfig):
+        extras.append("gpu")
+
+    client = current_client()
     job_request = JobRequest(
         name="convert-checkpoint-to-hf",
         entrypoint=Entrypoint.from_callable(_run_with_lockfile),
         resources=config.resources,
-        environment=EnvironmentConfig.create(env_vars=env),
+        environment=create_environment(env_vars=env, extras=extras),
     )
-
-    cluster = current_cluster()
-    job_id = cluster.launch(job_request)
-    cluster.wait(job_id, raise_on_failure=True)
+    job = client.submit(job_request)
+    job.wait(raise_on_failure=True)
 
 
 def convert_checkpoint_to_hf_step(
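
The second hunk makes two behavioral changes: it picks environment extras from the device type, and it waits on a job handle rather than passing a job id back to the cluster. The sketch below illustrates both without importing `fray`. Every class in it (`CpuConfig`, `GpuConfig`, `TpuConfig`, `FakeJob`, `FakeClient`) is a hypothetical stand-in written for this example, not the real `fray.v2` types, and `extras_for_device` is a made-up helper name; only the branch logic and the `submit` / `wait(raise_on_failure=True)` call shape are taken from the diff.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for the device config types (the real ones are in fray.v2.types).
@dataclass(frozen=True)
class CpuConfig:
    pass


@dataclass(frozen=True)
class GpuConfig:
    pass


@dataclass(frozen=True)
class TpuConfig:
    pass


def extras_for_device(device: object) -> list[str]:
    """Mirror the diff's branch: TPU -> ["tpu"], GPU -> ["gpu"], anything else -> []."""
    extras: list[str] = []
    if isinstance(device, TpuConfig):
        extras.append("tpu")
    elif isinstance(device, GpuConfig):
        extras.append("gpu")
    return extras


class FakeJob:
    """Hypothetical job handle: wait() is called on the handle itself."""

    def __init__(self, succeeded: bool) -> None:
        self.succeeded = succeeded

    def wait(self, raise_on_failure: bool = False) -> bool:
        if raise_on_failure and not self.succeeded:
            raise RuntimeError("job failed")
        return self.succeeded


class FakeClient:
    """Hypothetical client: submit() returns a handle instead of a bare job id."""

    def submit(self, request: dict) -> FakeJob:
        # A real client would schedule the request; this stub always succeeds.
        return FakeJob(succeeded=True)


if __name__ == "__main__":
    request = {"name": "convert-checkpoint-to-hf", "extras": extras_for_device(TpuConfig())}
    job = FakeClient().submit(request)
    job.wait(raise_on_failure=True)
    print(request["extras"])
```

Note that a CPU device yields an empty extras list, which matches the diff: neither branch fires for `CpuConfig`.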
