Commit e34cc94
refactor: run marin-on-iris test as standalone script for streaming logs
Replace the pytest wrapper with a standalone script that streams logs in real time. The test takes ~10 minutes and pytest swallows all output until completion, making failures hard to diagnose. The script runs as its own workflow step with `stream_logs=True` on the Iris job handle, so executor and child job output appears immediately. Inlines create_steps rather than importing from tests.integration_test to avoid sys.path hacks and to allow independent evolution of the pipeline steps for this test.
1 parent c54fb0d commit e34cc94
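
At its core, the change swaps pytest's buffered execution for direct job submission. A minimal sketch of the pattern, not a definitive implementation: the identifiers mirror the new script in this diff, with the callable body, real env vars, and workspace path elided or stubbed.

    # Minimal sketch of the streaming pattern this commit introduces; the
    # identifiers mirror the new script below, trimmed to essentials.
    from pathlib import Path

    from fray import ResourceConfig, set_current_client
    from fray.v2.iris_backend import FrayIrisClient
    from fray.v2.types import Entrypoint, JobRequest, create_environment


    def run_pipeline() -> None:
        ...  # executor_main(...) over the pipeline steps, as in the script


    # workspace should point at the repo root; "." is a stand-in here.
    client = FrayIrisClient(controller_address="http://localhost:10000", workspace=Path("."))
    with set_current_client(client):
        handle = client.submit(
            JobRequest(
                name="marin-itest-example",
                entrypoint=Entrypoint.from_callable(run_pipeline, args=()),
                resources=ResourceConfig.with_cpu(),
                environment=create_environment(env_vars={}),
            )
        )
        # stream_logs=True forwards executor and child-job output as it is
        # produced, instead of pytest buffering it until the test finishes.
        handle.wait(raise_on_failure=True, stream_logs=True)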

4 files changed: 334 additions & 142 deletions

.github/workflows/iris-coreweave-ci.yaml

Lines changed: 15 additions & 0 deletions
@@ -151,6 +151,21 @@ jobs:
             -o "addopts=" \
             -x
 
+      - name: Run full integration pipeline
+        env:
+          WANDB_MODE: disabled
+          WANDB_API_KEY: ""
+          JAX_TRACEBACK_FILTERING: off
+          MARIN_CI_S3_PREFIX: s3://marin-na/temp/ci
+          AWS_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
+          AWS_ENDPOINT_URL: https://74981a43be0de7712369306c7b19133d.r2.cloudflarestorage.com
+          FSSPEC_S3: '{"endpoint_url": "https://74981a43be0de7712369306c7b19133d.r2.cloudflarestorage.com"}'
+        run: |
+          IRIS_CONTROLLER_URL="http://localhost:10000"
+          timeout 600 uv run tests/integration/iris/run_iris_full_integration.py \
+            --controller-url "$IRIS_CONTROLLER_URL"
+
       - name: Stop port-forward
         if: always()
         run: |
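
A note on the FSSPEC_S3 variable above: fsspec reads FSSPEC_<PROTOCOL> environment variables as JSON and merges them into that protocol's default storage options, which is how every s3:// path in the test (and in child jobs that inherit the env) resolves against the R2 endpoint without code changes. A rough sketch of that mechanism, with a placeholder account id:

    import json
    import os

    # fsspec merges FSSPEC_<PROTOCOL> env vars (parsed as JSON) into that
    # protocol's default storage options when the library is imported, so
    # set the variable first. The account id below is a placeholder.
    os.environ["FSSPEC_S3"] = json.dumps(
        {"endpoint_url": "https://<account-id>.r2.cloudflarestorage.com"}
    )

    import fsspec  # noqa: E402  (deliberately imported after the env var is set)

    fs = fsspec.filesystem("s3")  # this S3 filesystem now targets the R2 endpoint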

.github/workflows/iris-integration.yaml

Lines changed: 9 additions & 0 deletions
@@ -77,6 +77,15 @@ jobs:
           WANDB_API_KEY: ""
           JAX_TRACEBACK_FILTERING: off
 
+      - name: Run full integration pipeline
+        run: |
+          timeout 600 uv run tests/integration/iris/run_iris_full_integration.py \
+            --controller-url "$IRIS_CONTROLLER_URL"
+        env:
+          WANDB_MODE: disabled
+          WANDB_API_KEY: ""
+          JAX_TRACEBACK_FILTERING: off
+
       - name: Stop cluster
         if: always()
         run: kill $CLUSTER_PID 2>/dev/null || true
tests/integration/iris/run_iris_full_integration.py

Lines changed: 310 additions & 0 deletions
@@ -0,0 +1,310 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Full marin data pipeline integration test on an Iris cluster.

Standalone script (not pytest) so logs stream in real time.

Usage:
    uv run tests/integration/iris/run_iris_full_integration.py \
        --controller-url http://localhost:10000

When MARIN_CI_S3_PREFIX is set, uploads test fixtures to S3 and submits
the executor as an Iris job so child jobs inherit S3 credentials.
Otherwise runs in-process against the local filesystem.
"""

import argparse
import logging
import os
import shutil
import sys
import tempfile
import uuid
from pathlib import Path

import fsspec
from fray import ResourceConfig, set_current_client
from fray.v2.iris_backend import FrayIrisClient
from fray.v2.types import Entrypoint, JobRequest, create_environment
from iris.logging import configure_logging
from levanter.main.train_lm import TrainLmConfig
from levanter.models.gpt2 import Gpt2Config
from levanter.trainer import TrainerConfig
from marin.execution.executor import (
    ExecutorMainConfig,
    ExecutorStep,
    executor_main,
    this_output_path,
)
from marin.execution.step_spec import StepSpec
from marin.processing.classification.consolidate import ConsolidateConfig, FilterConfig, FilterType, consolidate
from marin.processing.classification.deduplication.exact import dedup_exact_paragraph
from marin.processing.classification.deduplication.fuzzy import dedup_fuzzy_document
from marin.processing.tokenize import lm_data_config
from marin.processing.tokenize.tokenize import TokenizeConfig, tokenize
from marin.schemas.web.convert import ResiliparseConfig
from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm
from marin.transform.simple_html_to_md.process import SimpleHtmlToMdConfig, html_to_md

configure_logging(level=logging.INFO)
logger = logging.getLogger(__name__)

REPO_ROOT = Path(__file__).resolve().parents[3]
LOCAL_SYNTH_DATA = REPO_ROOT / "tests" / "quickstart-data"

_S3_ENV_KEYS = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_ENDPOINT_URL", "FSSPEC_S3"]


def create_steps(prefix: str, synth_data: str) -> list[ExecutorStep]:
    """Build the full marin data pipeline as executor steps."""

    # Transform HTML to markdown
    transform_hq_data_spec = StepSpec(
        name=os.path.join(prefix, "hq-transformed"),
        hash_attrs={"extract_method": "resiliparse"},
        fn=lambda output_path: html_to_md(
            SimpleHtmlToMdConfig(
                input_path=os.path.join(synth_data, "pos"),
                output_path=output_path,
                extract_method="resiliparse",
                config=ResiliparseConfig(),
            )
        ),
    )
    transform_lq_data_spec = StepSpec(
        name=os.path.join(prefix, "lq-transformed"),
        hash_attrs={"extract_method": "resiliparse"},
        fn=lambda output_path: html_to_md(
            SimpleHtmlToMdConfig(
                input_path=os.path.join(synth_data, "neg"),
                output_path=output_path,
                extract_method="resiliparse",
                config=ResiliparseConfig(),
            )
        ),
    )
    transform_hq_data_step = transform_hq_data_spec.as_executor_step()
    transform_lq_data_step = transform_lq_data_spec.as_executor_step()

    # Dedup
    dedup_exact_paragraph_spec = StepSpec(
        name=os.path.join(prefix, "dedup_exact_paragraph"),
        hash_attrs={"mode": "exact_paragraph"},
        deps=[transform_hq_data_spec],
        fn=lambda output_path: dedup_exact_paragraph(
            input_paths=transform_hq_data_spec.output_path,
            output_path=output_path,
            max_parallelism=4,
            worker_resources=ResourceConfig(cpu=1, ram="1g"),
        ),
    )
    dedup_fuzzy_document_spec = StepSpec(
        name=os.path.join(prefix, "dedup_fuzzy_document"),
        hash_attrs={"mode": "fuzzy_document"},
        deps=[transform_hq_data_spec],
        fn=lambda output_path: dedup_fuzzy_document(
            input_paths=transform_hq_data_spec.output_path,
            output_path=output_path,
            max_parallelism=4,
            worker_resources=ResourceConfig(cpu=1, ram="1g"),
        ),
    )
    dedup_exact_paragraph_step = dedup_exact_paragraph_spec.as_executor_step()
    dedup_fuzzy_document_step = dedup_fuzzy_document_spec.as_executor_step()

    # Consolidate
    consolidate_step = ExecutorStep(
        name=os.path.join(prefix, "cleaned"),
        fn=consolidate,
        config=ConsolidateConfig(
            input_path=transform_hq_data_step,
            output_path=this_output_path(),
            filters=[
                FilterConfig(
                    type=FilterType.REMOVE_SPANS,
                    attribute_path=dedup_exact_paragraph_step.cd("data"),
                    name="dup_spans",
                    attribute_filetype="vortex",
                    keep_if_missing=True,
                ),
                FilterConfig(
                    type=FilterType.REMOVE_DOC,
                    attribute_path=dedup_fuzzy_document_step.cd("data"),
                    name="dup_doc",
                    attribute_filetype="vortex",
                    keep_if_missing=True,
                ),
            ],
        ),
    )

    # Tokenize
    tokenize_step = ExecutorStep(
        name=os.path.join(prefix, "tokenized"),
        fn=tokenize,
        config=TokenizeConfig(
            train_paths=[consolidate_step],
            validation_paths=[],
            cache_path=this_output_path(),
            tokenizer="gpt2",
        ),
    )

    # Train (tiny model for validation)
    train_step = ExecutorStep(
        name=os.path.join(prefix, "train"),
        fn=run_levanter_train_lm,
        config=TrainLmOnPodConfig(
            output_path=this_output_path(),
            resources=ResourceConfig.with_cpu(),
            env_vars={
                "WANDB_API_KEY": "",
                "WANDB_MODE": "disabled",
                "JAX_TRACEBACK_FILTERING": "off",
            },
            train_config=TrainLmConfig(
                data=lm_data_config(tokenize_step),
                hf_save_steps=1,
                model=Gpt2Config(
                    num_layers=2,
                    num_heads=2,
                    max_seq_len=64,
                    hidden_dim=32,
                ),
                trainer=TrainerConfig(
                    train_batch_size=8, num_train_steps=2, max_eval_batches=1, require_accelerator=False
                ),
            ),
        ),
    )

    return [
        transform_hq_data_step,
        transform_lq_data_step,
        dedup_exact_paragraph_step,
        dedup_fuzzy_document_step,
        consolidate_step,
        tokenize_step,
        train_step,
    ]


# ---------------------------------------------------------------------------
# S3 helpers
# ---------------------------------------------------------------------------


def _upload_tree(local_root: Path, s3_dest: str) -> None:
    fs, _ = fsspec.core.url_to_fs(s3_dest)
    for path in local_root.rglob("*"):
        if not path.is_file():
            continue
        rel = path.relative_to(local_root)
        fs.put(str(path), f"{s3_dest}/{rel}")


def _rm_s3(s3_prefix: str) -> None:
    fs, _ = fsspec.core.url_to_fs(s3_prefix)
    try:
        fs.rm(s3_prefix, recursive=True)
    except FileNotFoundError:
        pass


def _s3_env_vars() -> dict[str, str]:
    return {k: os.environ[k] for k in _S3_ENV_KEYS if k in os.environ}


# ---------------------------------------------------------------------------
# Executor entry point (runs inside the Iris job on remote clusters)
# ---------------------------------------------------------------------------


def _run_executor(prefix: str, synth_data: str) -> None:
    config = ExecutorMainConfig(
        prefix=prefix,
        executor_info_base_path=f"{prefix}/experiments",
    )
    steps = create_steps("quickstart-tests", synth_data)
    executor_main(config, steps=steps)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def main():
    parser = argparse.ArgumentParser(description="Run full marin pipeline on Iris")
    parser.add_argument("--controller-url", required=True)
    args = parser.parse_args()

    s3_base = os.environ.get("MARIN_CI_S3_PREFIX")

    if s3_base:
        run_id = f"marin-itest-{uuid.uuid4().hex[:8]}"
        prefix = f"{s3_base}/{run_id}"
        synth_data = f"{prefix}/quickstart-data"
        logger.info("Uploading test fixtures to %s", synth_data)
        _upload_tree(LOCAL_SYNTH_DATA, synth_data)
        cleanup = lambda: _rm_s3(prefix)  # noqa: E731
    else:
        prefix = tempfile.mkdtemp(prefix="iris-marin-itest-")
        synth_data = str(LOCAL_SYNTH_DATA)
        cleanup = lambda: shutil.rmtree(prefix, ignore_errors=True)  # noqa: E731

    os.environ["MARIN_PREFIX"] = prefix
    os.environ["WANDB_MODE"] = "disabled"
    os.environ["WANDB_API_KEY"] = ""
    os.environ["JAX_TRACEBACK_FILTERING"] = "off"

    try:
        iris_client = FrayIrisClient(
            controller_address=args.controller_url,
            workspace=REPO_ROOT,
        )

        if s3_base:
            logger.info("Submitting executor as Iris job (S3 mode)")
            env_vars = {
                "MARIN_PREFIX": prefix,
                "WANDB_MODE": "disabled",
                "WANDB_API_KEY": "",
                "JAX_TRACEBACK_FILTERING": "off",
                **_s3_env_vars(),
            }

            with set_current_client(iris_client):
                handle = iris_client.submit(
                    JobRequest(
                        name=f"marin-itest-{uuid.uuid4().hex[:8]}",
                        entrypoint=Entrypoint.from_callable(
                            _run_executor,
                            args=(prefix, synth_data),
                        ),
                        resources=ResourceConfig.with_cpu(),
                        environment=create_environment(env_vars=env_vars),
                    )
                )
                handle.wait(raise_on_failure=True, stream_logs=True)
        else:
            logger.info("Running executor in-process (local mode)")
            config = ExecutorMainConfig(
                prefix=prefix,
                executor_info_base_path=f"{prefix}/experiments",
            )
            steps = create_steps("quickstart-tests", synth_data)
            with set_current_client(iris_client):
                executor_main(config, steps=steps)

        logger.info("Pipeline completed successfully")
    except Exception:
        logger.exception("Pipeline failed")
        sys.exit(1)
    finally:
        cleanup()


if __name__ == "__main__":
    main()
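
For orientation, the StepSpec pattern used throughout create_steps separates dependency ordering (deps) from data flow (reading a dependency's output_path inside the fn closure). A hypothetical two-step sketch of that wiring, under the same API shown above; process_fn is a stand-in, not a real marin function:

    # Hypothetical sketch of the StepSpec dependency pattern used in
    # create_steps: deps orders execution, and the downstream fn reads the
    # upstream step's output_path inside its closure.
    from marin.execution.step_spec import StepSpec


    def process_fn(input_paths, output_path):  # stand-in worker, hypothetical
        ...


    upstream = StepSpec(
        name="demo/upstream",
        hash_attrs={"stage": "upstream"},
        fn=lambda output_path: process_fn(input_paths=None, output_path=output_path),
    )
    downstream = StepSpec(
        name="demo/downstream",
        hash_attrs={"stage": "downstream"},
        deps=[upstream],  # runs only after upstream completes
        fn=lambda output_path: process_fn(
            input_paths=upstream.output_path,  # resolved to upstream's output dir
            output_path=output_path,
        ),
    )
    steps = [spec.as_executor_step() for spec in (upstream, downstream)]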
