Skip to content

Commit 536aac8

Browse files
committed
Create convenience decorator as_step_fn for building StepSpecs
This allows for decorating a function that you want to use in a `StepSpec` and avoiding the `lambda op: ...` boilerplate.
1 parent 30f6b6c commit 536aac8

3 files changed

Lines changed: 18 additions & 3 deletions

File tree

experiments/dedup/fineweb_10bt_exact.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,8 @@ def build_steps() -> list[StepSpec]:
3333
name="exact_dedup_fineweb_10bt",
3434
output_path_prefix=f"{marin_prefix()}/tmp/{OUTPUT_PREFIX}",
3535
deps=[download],
36-
fn=lambda op: dedup_exact_paragraph(
36+
fn=dedup_exact_paragraph(
3737
input_paths=os.path.join(download.output_path, "sample/10BT"),
38-
output_path=op,
3938
max_parallelism=128,
4039
),
4140
)

lib/marin/src/marin/execution/step_spec.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,27 @@
88
import json
99
from collections.abc import Callable
1010
from dataclasses import dataclass
11-
from functools import cached_property
11+
from functools import cached_property, wraps
1212
from typing import Any
1313
from urllib.parse import urlparse
1414

1515
from rigging.filesystem import marin_prefix
1616

1717

18+
def as_step_fn(fn: Callable[..., Any]) -> Callable[..., Any]:
19+
"""Decorator that allows a function to either be called normally with output_path,
20+
or curried by omitting output_path, returning a Callable[[str], Any] suitable for StepSpec.fn."""
21+
22+
@wraps(fn)
23+
def wrapper(*args: Any, **kwargs: Any) -> Any:
24+
def inner(output_path: str) -> Any:
25+
return fn(*args, output_path=output_path, **kwargs)
26+
27+
return inner
28+
29+
return wrapper
30+
31+
1832
def _is_relative_path(url_or_path: str) -> bool:
1933
"""Return True if the path is relative (not a URL and doesn't start with /)."""
2034
if urlparse(url_or_path).scheme:

lib/marin/src/marin/processing/classification/deduplication/exact.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from fray import ResourceConfig
2424
from zephyr import ZephyrContext, counters, write_parquet_file
2525
from zephyr.dataset import Dataset
26+
from marin.execution.step_spec import as_step_fn
2627

2728
logger = logging.getLogger(__name__)
2829

@@ -46,6 +47,7 @@ def _iter_has_more_than_one(records: Iterator[T]) -> tuple[bool, T, Iterator[T]]
4647
return has_more_than_one, first, itertools.chain([first], rest)
4748

4849

50+
@as_step_fn
4951
def dedup_exact_paragraph(
5052
*,
5153
input_paths: str | list[str],

0 commit comments

Comments
 (0)