Skip to content

Commit 7db6a14

Browse files
committed
marin: require pydantic BaseModel for typed Artifact.from_path
Drops the dataclass branch from typed loading — `artifact_type` must now be a pydantic BaseModel subclass, otherwise `from_path` raises TypeError (previously, unsupported types raised ValueError). All production call sites already pass pydantic types (`NormalizedData`, `MinHashAttrData`, `FuzzyDupsAttrData`); test types `TokenizeMetadata` / `TrainMetadata` are converted to BaseModel. `Artifact.save` still accepts dataclasses since they round-trip fine through untyped `from_path`, which returns the parsed JSON. Also removes `test_artifact_from_executor_status_relative_path`.
1 parent 6964fb6 commit 7db6a14

2 files changed

Lines changed: 6 additions & 21 deletions

File tree

lib/marin/src/marin/execution/artifact.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,9 @@ def from_path(
4848
with open_url(f"{base_path}/{cls.__artifact_file_name}", "rb") as fd:
4949
if artifact_type is None:
5050
return json.load(fd)
51-
if issubclass(artifact_type, BaseModel):
52-
return artifact_type.model_validate_json(fd.read())
53-
if is_dataclass(artifact_type):
54-
return artifact_type(**json.load(fd)) # type: ignore[not-callable]
55-
raise ValueError(f"Unsupported artifact type: {artifact_type!r}")
51+
if not issubclass(artifact_type, BaseModel):
52+
raise TypeError(f"artifact_type must be a pydantic BaseModel subclass, got {artifact_type!r}")
53+
return artifact_type.model_validate_json(fd.read())
5654
except FileNotFoundError:
5755
return cls._from_executor_status(base_path, artifact_type)
5856

tests/execution/test_step_runner.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,20 @@
1717
from marin.execution.remote import RemoteCallable, remote
1818
from marin.execution.step_runner import StepRunner
1919
from marin.execution.step_spec import StepSpec
20+
from pydantic import BaseModel
2021
from rigging.filesystem import MARIN_CROSS_REGION_OVERRIDE_ENV
2122

2223
# ---------------------------------------------------------------------------
2324
# Artifact types
2425
# ---------------------------------------------------------------------------
2526

2627

27-
@dataclass
28-
class TokenizeMetadata:
28+
class TokenizeMetadata(BaseModel):
2929
path: str
3030
num_tokens: int
3131

3232

33-
@dataclass
34-
class TrainMetadata:
33+
class TrainMetadata(BaseModel):
3534
tokens_seen: int
3635
checkpoint_path: str
3736

@@ -158,18 +157,6 @@ def test_artifact_from_executor_status_non_success_raises(tmp_path: Path):
158157
Artifact.from_path(tmp_path.as_posix())
159158

160159

161-
def test_artifact_from_executor_status_relative_path(tmp_path: Path, monkeypatch):
162-
"""Fallback also works through MARIN_PREFIX-resolved relative paths."""
163-
monkeypatch.setenv("MARIN_PREFIX", tmp_path.as_posix())
164-
step_dir = tmp_path / "step_out"
165-
step_dir.mkdir()
166-
(step_dir / ".executor_status").write_text("SUCCESS")
167-
168-
loaded = Artifact.from_path("step_out")
169-
assert isinstance(loaded, PathMetadata)
170-
assert loaded.path == step_dir.as_posix()
171-
172-
173160
def test_artifact_save_and_load_untyped(tmp_path: Path):
174161
artifact = TokenizeMetadata(path="/tokenized", num_tokens=42)
175162
Artifact.save(artifact, tmp_path.as_posix())

0 commit comments

Comments
 (0)