diff --git a/.github/workflows/iris-coreweave-ci.yaml b/.github/workflows/iris-coreweave-ci.yaml
new file mode 100644
index 0000000000..1bdec3dbab
--- /dev/null
+++ b/.github/workflows/iris-coreweave-ci.yaml
@@ -0,0 +1,202 @@
+name: Iris - CoreWeave CI
+
+on:
+  pull_request:
+    types: [opened, synchronize]
+    paths:
+      - "lib/iris/**"
+  issue_comment:
+    types: [created]
+  workflow_dispatch:
+
+permissions:
+  contents: read
+  packages: write
+  pull-requests: read # needed for issue_comment to access PR metadata
+  statuses: write # post commit status from issue_comment trigger
+
+# Single concurrency group — only one CW CI run at a time across all PRs.
+# The warm cluster is shared; concurrent runs would conflict.
+concurrency:
+  group: iris-coreweave-ci
+  cancel-in-progress: false
+
+jobs:
+  cw-ci-test:
+    if: >-
+      (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
+      github.event_name == 'workflow_dispatch' ||
+      (
+        github.event_name == 'issue_comment' &&
+        github.event.issue.pull_request &&
+        contains(github.event.comment.body, '/iris-ci-cw') &&
+        (
+          github.event.comment.author_association == 'MEMBER' ||
+          github.event.comment.author_association == 'COLLABORATOR' ||
+          github.event.comment.author_association == 'OWNER'
+        )
+      )
+    runs-on: ubuntu-latest
+    timeout-minutes: 60
+    env:
+      IRIS_NAMESPACE: iris-ci
+      # Must match Labels(label_prefix).iris_managed from the cluster config
+      IRIS_MANAGED_LABEL: iris-iris-ci-managed
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+        with:
+          ref: ${{ github.event_name == 'issue_comment' && format('refs/pull/{0}/head', github.event.issue.number) || '' }}
+
+      - name: Set commit status to pending
+        if: github.event_name == 'issue_comment'
+        env:
+          GH_TOKEN: ${{ github.token }}
+        run: |
+          sha=$(git rev-parse HEAD)
+          gh api repos/${{ github.repository }}/statuses/"$sha" \
+            -f state=pending \
+            -f context="Iris CoreWeave CI" \
+            -f target_url="${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}" || true
+
+      - name: Set up Python 3.12
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v7
+        with:
+          enable-cache: true
+          cache-dependency-glob: "lib/iris/pyproject.toml"
+
+      - name: Write kubeconfig
+        run: |
+          mkdir -p ~/.kube
+          echo "${{ secrets.CW_KUBECONFIG }}" > ~/.kube/coreweave-iris
+          chmod 600 ~/.kube/coreweave-iris
+
+      - name: Log in to GitHub Container Registry
+        uses: docker/login-action@v3
+        with:
+          registry: ghcr.io
+          username: ${{ github.actor }}
+          password: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      # Delete stale worker pods so the autoscaler recreates them with fresh images.
+      # Nodepools (and their underlying nodes) survive — this is the "warm start".
+      - name: Reset worker pods
+        run: |
+          export KUBECONFIG=~/.kube/coreweave-iris
+          kubectl delete pods -n "$IRIS_NAMESPACE" -l "$IRIS_MANAGED_LABEL=true" --grace-period=0 --ignore-not-found || true
+
+      # Rebuild images and (re)start the controller. `cluster start` is fully
+      # idempotent on K8s: it applies namespace/RBAC/ConfigMap/Deployment/Service
+      # and triggers a rollout restart, so both cold starts and warm restarts
+      # work without needing to tunnel to an existing controller first.
+      - name: Start controller
+        env:
+          R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
+          R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
+        run: |
+          cd lib/iris && uv run --group dev iris -v \
+            --config=examples/coreweave-ci.yaml \
+            cluster start
+
+      - name: Run integration tests
+        env:
+          WANDB_MODE: disabled
+          WANDB_API_KEY: ""
+          JAX_TRACEBACK_FILTERING: off
+          # When set, the marin-on-iris test uploads fixtures and writes
+          # intermediate data to S3 (R2) so remote Zephyr pods can access them.
+          MARIN_CI_S3_PREFIX: s3://marin-na/temp/ci
+          AWS_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
+          AWS_ENDPOINT_URL: https://74981a43be0de7712369306c7b19133d.r2.cloudflarestorage.com
+          FSSPEC_S3: '{"endpoint_url": "https://74981a43be0de7712369306c7b19133d.r2.cloudflarestorage.com"}'
+        run: |
+          export KUBECONFIG=~/.kube/coreweave-iris
+          kubectl port-forward -n "$IRIS_NAMESPACE" svc/iris-ci-controller-svc 10000:10000 &
+          PF_PID=$!
+          echo "PF_PID=$PF_PID" >> "$GITHUB_ENV"
+
+          IRIS_CONTROLLER_URL="http://localhost:10000"
+
+          # Controller deployment is already confirmed ready by `cluster start`;
+          # this just waits for the port-forward to be usable.
+          HEALTHY=false
+          for i in $(seq 1 60); do
+            if ! kill -0 "$PF_PID" 2>/dev/null; then
+              echo "port-forward process died unexpectedly"
+              exit 1
+            fi
+            if curl -sf "$IRIS_CONTROLLER_URL/health" > /dev/null 2>&1; then
+              HEALTHY=true
+              break
+            fi
+            sleep 5
+          done
+          if [ "$HEALTHY" != "true" ]; then
+            echo "Controller did not become healthy within timeout"
+            exit 1
+          fi
+
+          uv run pytest tests/integration/iris/ \
+            --controller-url "$IRIS_CONTROLLER_URL" \
+            -v --tb=short --timeout=600 \
+            -o "addopts=" \
+            -x
+
+      - name: Run full integration pipeline
+        env:
+          WANDB_MODE: disabled
+          WANDB_API_KEY: ""
+          JAX_TRACEBACK_FILTERING: off
+          MARIN_CI_S3_PREFIX: s3://marin-na/temp/ci
+          AWS_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
+          AWS_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
+          AWS_ENDPOINT_URL: https://74981a43be0de7712369306c7b19133d.r2.cloudflarestorage.com
+          FSSPEC_S3: '{"endpoint_url": "https://74981a43be0de7712369306c7b19133d.r2.cloudflarestorage.com"}'
+        run: |
+          IRIS_CONTROLLER_URL="http://localhost:10000"
+          timeout 600 uv run tests/integration/iris/run_iris_full_integration.py \
+            --controller-url "$IRIS_CONTROLLER_URL"
+
+      - name: Stop port-forward
+        if: always()
+        run: |
+          [ -n "$PF_PID" ] && kill "$PF_PID" 2>/dev/null || true
+          pkill -f "kubectl port-forward.*$IRIS_NAMESPACE" 2>/dev/null || true
+
+      - name: Capture failure diagnostics
+        if: failure()
+        run: |
+          export KUBECONFIG=~/.kube/coreweave-iris
+          echo "=== Controller logs ==="
+          kubectl -n "$IRIS_NAMESPACE" logs -l app=iris-controller --tail=500 || true
+          echo "=== Controller pod describe ==="
+          kubectl -n "$IRIS_NAMESPACE" describe pod -l app=iris-controller || true
+          echo "=== Worker pods ==="
+          kubectl -n "$IRIS_NAMESPACE" get pods -l "$IRIS_MANAGED_LABEL=true" || true
+          echo "=== Warning events ==="
+          kubectl -n "$IRIS_NAMESPACE" get events --sort-by='.lastTimestamp' --field-selector type!=Normal || true
+
+      - name: Set commit status to result
+        if: always() && github.event_name == 'issue_comment'
+        env:
+          GH_TOKEN: ${{ github.token }}
+        run: |
+          sha=$(git rev-parse HEAD)
+          if [ "${{ job.status }}" = "success" ]; then
+            state=success
+          else
+            state=failure
+          fi
+          gh api repos/${{ github.repository }}/statuses/"$sha" \
+            -f state="$state" \
+            -f context="Iris CoreWeave CI" \
+            -f target_url="${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
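Reviewer note: the "Reset worker pods" step is the heart of the warm-start scheme — pods go, nodes stay. For reference, a minimal sketch of the same label-selector reset using the official `kubernetes` Python client; the namespace and label values mirror the workflow env above, but the function itself is illustrative, not part of this PR:

```python
# Sketch only: the CI job shells out to kubectl; this is the equivalent call
# with the official `kubernetes` Python client. Namespace and label mirror
# IRIS_NAMESPACE / IRIS_MANAGED_LABEL above (assumed values).
from kubernetes import client, config


def reset_worker_pods(namespace: str = "iris-ci",
                      managed_label: str = "iris-iris-ci-managed") -> None:
    """Delete all Iris-managed worker pods; the autoscaler recreates them
    with freshly pulled images while the warm nodepools stay up."""
    config.load_kube_config()  # honors $KUBECONFIG, like kubectl
    v1 = client.CoreV1Api()
    v1.delete_collection_namespaced_pod(
        namespace=namespace,
        label_selector=f"{managed_label}=true",
        grace_period_seconds=0,
    )
```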
diff --git a/.github/workflows/iris-integration.yaml b/.github/workflows/iris-integration.yaml
index 0328460031..d2aa030fe9 100644
--- a/.github/workflows/iris-integration.yaml
+++ b/.github/workflows/iris-integration.yaml
@@ -69,7 +69,7 @@ jobs:
         run: |
           uv run pytest tests/integration/iris/ \
             --controller-url "$IRIS_CONTROLLER_URL" \
-            -v --tb=short --timeout=600 \
+            -v -s --log-cli-level=INFO --tb=short --timeout=600 \
             -o "addopts=" \
             -x
         env:
@@ -77,6 +77,15 @@ jobs:
           WANDB_API_KEY: ""
           JAX_TRACEBACK_FILTERING: off
 
+      - name: Run full integration pipeline
+        run: |
+          timeout 600 uv run tests/integration/iris/run_iris_full_integration.py \
+            --controller-url "$IRIS_CONTROLLER_URL"
+        env:
+          WANDB_MODE: disabled
+          WANDB_API_KEY: ""
+          JAX_TRACEBACK_FILTERING: off
+
       - name: Stop cluster
         if: always()
         run: kill $CLUSTER_PID 2>/dev/null || true
diff --git a/lib/iris/examples/coreweave-ci.yaml b/lib/iris/examples/coreweave-ci.yaml
new file mode 100644
index 0000000000..4f73535bac
--- /dev/null
+++ b/lib/iris/examples/coreweave-ci.yaml
@@ -0,0 +1,91 @@
+# Persistent CoreWeave CI cluster. Both scale groups are pinned at min=max=1
+# so nodes stay warm between runs — only controller and worker pods are reset.
+
+platform:
+  label_prefix: iris-ci
+  coreweave:
+    region: US-WEST-04A
+    namespace: iris-ci
+    kubeconfig_path: ~/.kube/coreweave-iris
+    object_storage_endpoint: https://74981a43be0de7712369306c7b19133d.r2.cloudflarestorage.com
+
+storage:
+  remote_state_dir: s3://marin-na/iris/state/ci
+
+kubernetes_provider:
+  namespace: iris-ci
+  default_image: ghcr.io/marin-community/iris-task:latest
+  host_network: true
+  cache_dir: /mnt/local/iris-cache
+  controller_address: http://iris-ci-controller-svc.iris-ci.svc.cluster.local:10000
+
+controller:
+  image: ghcr.io/marin-community/iris-controller:latest
+  coreweave:
+    port: 10000
+    service_name: iris-ci-controller-svc
+    scale_group: cpu-erapids
+
+defaults:
+  autoscaler:
+    evaluation_interval:
+      milliseconds: 10000
+    scale_up_delay:
+      milliseconds: 60000
+    scale_down_delay:
+      milliseconds: 300000
+    startup_grace_period:
+      milliseconds: 1200000 # 20 min — nodes are pinned warm so this rarely fires
+  task_env:
+    MARIN_PREFIX: s3://marin-na/marin
+  worker:
+    docker_image: ghcr.io/marin-community/iris-worker:latest
+    port: 10001
+    cache_dir: /mnt/local/iris-cache
+    runtime: kubernetes
+    default_task_image: ghcr.io/marin-community/iris-task:latest
+
+scale_groups:
+  cpu-erapids:
+    num_vms: 1
+    resources:
+      cpu: 64
+      ram: 256GB
+      disk: 1TB
+      device_type: cpu
+    preemptible: false
+    worker:
+      attributes:
+        region: US-WEST-04A
+        pool: cpu-erapids
+    min_slices: 1
+    max_slices: 1
+    priority: 50
+    slice_template:
+      num_vms: 1
+      coreweave:
+        region: US-WEST-04A
+        instance_type: cd-gp-i64-erapids
+
+  h100-8x:
+    num_vms: 1
+    resources:
+      cpu: 128
+      ram: 2048GB
+      disk: 1TB
+      device_type: gpu
+      device_variant: H100
+      device_count: 8
+    preemptible: false
+    worker:
+      attributes:
+        region: US-WEST-04A
+        pool: h100-8x
+    min_slices: 1
+    max_slices: 1
+    priority: 100
+    slice_template:
+      num_vms: 1
+      coreweave:
+        region: US-WEST-04A
+        instance_type: gd-8xh100ib-i128
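Reviewer note: one coupling worth calling out — the workflow's IRIS_MANAGED_LABEL (`iris-iris-ci-managed`) must agree with what Iris derives from `label_prefix: iris-ci`, per the comment in the workflow. The real `Labels` helper lives in the Iris source; the sketch below is a hypothetical reconstruction of that derivation, shown only to make the invariant concrete:

```python
# Hypothetical sketch of the label derivation. The real Labels class lives in
# Iris; the field name and scheme here are inferred from the workflow comment
# "Must match Labels(label_prefix).iris_managed", not taken from the source.
from dataclasses import dataclass


@dataclass(frozen=True)
class Labels:
    label_prefix: str

    @property
    def iris_managed(self) -> str:
        # "iris-ci" -> "iris-iris-ci-managed"
        return f"iris-{self.label_prefix}-managed"


assert Labels("iris-ci").iris_managed == "iris-iris-ci-managed"
```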
diff --git a/lib/iris/src/iris/cluster/controller/service.py b/lib/iris/src/iris/cluster/controller/service.py
index 8d049d678c..4a0e2bf153 100644
--- a/lib/iris/src/iris/cluster/controller/service.py
+++ b/lib/iris/src/iris/cluster/controller/service.py
@@ -1870,6 +1870,18 @@ def exec_in_container(
         task_worker_id = task.worker_id
 
         if not task_worker_id:
+            if self._controller.has_direct_provider:
+                provider = self._controller.provider
+                timeout = request.timeout_seconds if request.timeout_seconds else 60
+                resp = provider.exec_in_container(
+                    task.task_id.to_wire(), task.current_attempt_id, list(request.command), timeout
+                )
+                return cluster_pb2.Controller.ExecInContainerResponse(
+                    exit_code=resp.exit_code,
+                    stdout=resp.stdout,
+                    stderr=resp.stderr,
+                    error=resp.error,
+                )
             raise ConnectError(Code.FAILED_PRECONDITION, f"Task {request.task_id} not assigned to a worker")
 
         worker = _read_worker(self._db, task_worker_id)
diff --git a/lib/iris/src/iris/cluster/providers/k8s/tasks.py b/lib/iris/src/iris/cluster/providers/k8s/tasks.py
index 135261a07f..9c091f4c89 100644
--- a/lib/iris/src/iris/cluster/providers/k8s/tasks.py
+++ b/lib/iris/src/iris/cluster/providers/k8s/tasks.py
@@ -693,6 +693,26 @@ def profile_task(
         except Exception as e:
             return cluster_pb2.ProfileTaskResponse(error=str(e))
 
+    def exec_in_container(
+        self,
+        task_id: str,
+        attempt_id: int,
+        command: list[str],
+        timeout_seconds: int = 60,
+    ) -> cluster_pb2.Worker.ExecInContainerResponse:
+        """Execute a command in a running task pod via kubectl exec."""
+        pod_name = _pod_name(JobName.from_wire(task_id), attempt_id)
+        effective_timeout: float | None = timeout_seconds if timeout_seconds >= 0 else None
+        try:
+            result = self.kubectl.exec(pod_name, command, container="task", timeout=effective_timeout)
+            return cluster_pb2.Worker.ExecInContainerResponse(
+                exit_code=result.returncode,
+                stdout=result.stdout,
+                stderr=result.stderr,
+            )
+        except Exception as e:
+            return cluster_pb2.Worker.ExecInContainerResponse(error=str(e))
+
     def close(self) -> None:
         """No persistent resources to release."""
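Reviewer note: both hunks bottom out in `self.kubectl.exec`, which is internal to Iris, so its exact signature is not shown here. As a rough sketch of what such a wrapper plausibly does, with the same timeout convention as `exec_in_container` above (a negative `timeout_seconds` maps to `None`, i.e. wait indefinitely); argument names are assumptions:

```python
# Rough sketch of a kubectl-exec wrapper matching the timeout semantics of
# exec_in_container above (timeout=None means wait indefinitely). The real
# self.kubectl.exec is internal to Iris; names and defaults here are assumed.
import subprocess


def kubectl_exec(pod_name: str, command: list[str], *,
                 namespace: str = "iris-ci", container: str = "task",
                 timeout: float | None = None) -> subprocess.CompletedProcess:
    argv = [
        "kubectl", "exec", "-n", namespace, pod_name,
        "-c", container, "--", *command,
    ]
    # capture_output provides the stdout/stderr that go into the
    # ExecInContainerResponse; subprocess raises TimeoutExpired if the
    # command outlives `timeout`, which the caller turns into an error field.
    return subprocess.run(argv, capture_output=True, text=True, timeout=timeout)
```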
diff --git a/lib/marin/src/marin/processing/classification/classifier.py b/lib/marin/src/marin/processing/classification/classifier.py
index 580c564f4f..e86a1ee4f4 100644
--- a/lib/marin/src/marin/processing/classification/classifier.py
+++ b/lib/marin/src/marin/processing/classification/classifier.py
@@ -81,7 +81,7 @@ def load_model(self):
 
         with FileLock(lock_file):
             if not os.path.exists(success_file):
-                fs.makedirs(f"/tmp/{model_descriptor}", exist_ok=True)
+                os.makedirs(f"/tmp/{model_descriptor}", exist_ok=True)
 
                 if is_remote_or_local_path:
                     fs.get(fs_path, local_filepath)
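Reviewer note: this one-word fix is easy to misread as cosmetic. `fs` is the fsspec filesystem of the model's source path, which may be remote, so `fs.makedirs("/tmp/...")` asks the remote backend to create that directory instead of preparing a local scratch dir. A small illustration (bucket and paths are invented):

```python
# Illustration of the bug being fixed (bucket and paths are made up).
# `fs` belongs to the *source* of the model, which may be remote:
import os

import fsspec

fs, fs_path = fsspec.core.url_to_fs("s3://some-bucket/models/quality.bin")

# Wrong: asks the S3 filesystem to create "/tmp/...", i.e. a bucket-relative
# remote path (or an error), not a local scratch directory.
# fs.makedirs("/tmp/quality", exist_ok=True)

# Right: the scratch directory is local, so use os.makedirs.
os.makedirs("/tmp/quality", exist_ok=True)
```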
""" - import datasets - - # Disable caching for streaming - datasets.disable_caching() - datasets.logging.set_verbosity_warning() - - # Determine file type and load with streaming if input_filename.endswith((".jsonl.gz", ".jsonl.zst", ".jsonl")): - # Load as JSON lines with streaming - dataset = datasets.load_dataset("json", data_files=input_filename, streaming=True, split="train") + with open_url(input_filename, "rb", compression="infer") as f: + for line in f: + row = json.loads(line) + if columns: + row = {k: row[k] for k in columns if k in row} + yield row elif input_filename.endswith(".parquet"): - # Load parquet with streaming - dataset = datasets.load_dataset("parquet", data_files=input_filename, streaming=True, split="train") + import pyarrow.parquet as pq + + with open_url(input_filename, "rb") as f: + table = pq.read_table(f, columns=columns) + for row in table.to_pylist(): + yield row else: raise ValueError(f"Unsupported filetype: {input_filename}") - # Filter columns if specified - if columns: - dataset = dataset.select_columns(columns) - - # Yield rows from the streaming dataset - yield from dataset - def write_dataset_streaming(rows_iterator, output_filename: str, append: bool = False): """Writes rows to a file in streaming fashion diff --git a/tests/integration/iris/run_iris_full_integration.py b/tests/integration/iris/run_iris_full_integration.py new file mode 100644 index 0000000000..905913a208 --- /dev/null +++ b/tests/integration/iris/run_iris_full_integration.py @@ -0,0 +1,289 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""Full marin data pipeline integration test on an Iris cluster. + +Standalone script (not pytest) so logs stream in real time. + +Usage: + uv run tests/integration/iris/run_iris_full_integration.py \ + --controller-url http://localhost:10000 + +When MARIN_CI_S3_PREFIX is set, uploads test fixtures to S3 and submits +the executor as an Iris job so child jobs inherit S3 credentials. +Otherwise runs in-process against local filesystem. 
+""" + +import argparse +import logging +import os +import shutil +import sys +import tempfile +import uuid +from pathlib import Path + +import fsspec +from fray import ResourceConfig, set_current_client +from fray.v2.iris_backend import FrayIrisClient +from fray.v2.types import Entrypoint, JobRequest, create_environment +from iris.logging import configure_logging +from levanter.main.train_lm import TrainLmConfig +from levanter.models.gpt2 import Gpt2Config +from levanter.trainer import TrainerConfig +from marin.execution.executor import ( + ExecutorMainConfig, + ExecutorStep, + executor_main, + this_output_path, +) +from marin.execution.step_spec import StepSpec +from marin.processing.classification.consolidate import ConsolidateConfig, FilterConfig, FilterType, consolidate +from marin.processing.classification.deduplication.exact import dedup_exact_paragraph +from marin.processing.tokenize import lm_data_config +from marin.processing.tokenize.tokenize import TokenizeConfig, tokenize +from marin.schemas.web.convert import ResiliparseConfig +from marin.training.training import TrainLmOnPodConfig, run_levanter_train_lm +from marin.transform.simple_html_to_md.process import SimpleHtmlToMdConfig, html_to_md + +configure_logging(level=logging.INFO) +logger = logging.getLogger(__name__) + +REPO_ROOT = Path(__file__).resolve().parents[3] +LOCAL_SYNTH_DATA = REPO_ROOT / "tests" / "quickstart-data" + +_S3_ENV_KEYS = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_ENDPOINT_URL", "FSSPEC_S3"] + + +def create_steps(prefix: str, synth_data: str) -> list[ExecutorStep]: + """Build the full marin data pipeline as executor steps.""" + + # Transform HTML to markdown + transform_hq_data_spec = StepSpec( + name=os.path.join(prefix, "hq-transformed"), + hash_attrs={"extract_method": "resiliparse"}, + fn=lambda output_path: html_to_md( + SimpleHtmlToMdConfig( + input_path=os.path.join(synth_data, "pos"), + output_path=output_path, + extract_method="resiliparse", + config=ResiliparseConfig(), + ) + ), + ) + transform_lq_data_spec = StepSpec( + name=os.path.join(prefix, "lq-transformed"), + hash_attrs={"extract_method": "resiliparse"}, + fn=lambda output_path: html_to_md( + SimpleHtmlToMdConfig( + input_path=os.path.join(synth_data, "neg"), + output_path=output_path, + extract_method="resiliparse", + config=ResiliparseConfig(), + ) + ), + ) + transform_hq_data_step = transform_hq_data_spec.as_executor_step() + transform_lq_data_step = transform_lq_data_spec.as_executor_step() + + # Dedup (exact only — fuzzy dedup has 4 iterative rounds of pod scheduling on K8s) + dedup_exact_paragraph_spec = StepSpec( + name=os.path.join(prefix, "dedup_exact_paragraph"), + hash_attrs={"mode": "exact_paragraph"}, + deps=[transform_hq_data_spec], + fn=lambda output_path: dedup_exact_paragraph( + input_paths=transform_hq_data_spec.output_path, + output_path=output_path, + max_parallelism=4, + worker_resources=ResourceConfig(cpu=1, ram="1g"), + ), + ) + dedup_exact_paragraph_step = dedup_exact_paragraph_spec.as_executor_step() + + # Consolidate + consolidate_step = ExecutorStep( + name=os.path.join(prefix, "cleaned"), + fn=consolidate, + config=ConsolidateConfig( + input_path=transform_hq_data_step, + output_path=this_output_path(), + filters=[ + FilterConfig( + type=FilterType.REMOVE_SPANS, + attribute_path=dedup_exact_paragraph_step.cd("data"), + name="dup_spans", + attribute_filetype="vortex", + keep_if_missing=True, + ), + ], + ), + ) + + # Tokenize + tokenize_step = ExecutorStep( + name=os.path.join(prefix, "tokenized"), + 
diff --git a/tests/integration/iris/test_iris_integration.py b/tests/integration/iris/test_iris_integration.py
index 7ab3e6270b..b122a3ec83 100644
--- a/tests/integration/iris/test_iris_integration.py
+++ b/tests/integration/iris/test_iris_integration.py
@@ -30,7 +30,6 @@
     fail,
     sleep,
     register_endpoint,
-    validate_ports,
 )
 
 logger = logging.getLogger(__name__)
@@ -88,13 +87,6 @@ def test_endpoint_registration(integration_cluster):
     assert status.state == cluster_pb2.JOB_STATE_SUCCEEDED
 
 
-def test_port_allocation(integration_cluster):
-    """Port allocation job succeeded."""
-    job = integration_cluster.submit(validate_ports, "itest-ports", ports=["http", "grpc"])
-    status = integration_cluster.wait(job, timeout=integration_cluster.job_timeout)
-    assert status.state == cluster_pb2.JOB_STATE_SUCCEEDED
-
-
 def test_reservation_gates_scheduling(integration_cluster):
     """Unsatisfiable reservation blocks scheduling; regular jobs proceed."""
     with integration_cluster.launched_job(
diff --git a/tests/integration/iris/test_marin_on_iris.py b/tests/integration/iris/test_marin_on_iris.py
deleted file mode 100644
index d6e2361c71..0000000000
--- a/tests/integration/iris/test_marin_on_iris.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Copyright The Marin Authors
-# SPDX-License-Identifier: Apache-2.0
-
-"""Full marin integration pipeline running on Iris.
-
-Ports the pipeline from tests/integration_test.py to dispatch through an Iris
-cluster via FrayIrisClient instead of Ray.
-"""
-
-import logging
-import os
-import shutil
-import tempfile
-from pathlib import Path
-
-import pytest
-from fray import set_current_client
-from fray.v2.iris_backend import FrayIrisClient
-from marin.execution.executor import ExecutorMainConfig, executor_main
-from tests.integration_test import create_steps
-
-logger = logging.getLogger(__name__)
-
-REPO_ROOT = Path(__file__).resolve().parents[3]
-SYNTH_DATA = str(REPO_ROOT / "tests" / "quickstart-data")
-
-pytestmark = [pytest.mark.integration, pytest.mark.slow]
-
-
-@pytest.mark.timeout(600)
-def test_marin_pipeline_on_iris(integration_cluster, monkeypatch):
-    """Run the full marin data pipeline dispatched through Iris."""
-    prefix = tempfile.mkdtemp(prefix="iris-marin-itest-")
-    try:
-        monkeypatch.setenv("MARIN_PREFIX", prefix)
-        monkeypatch.setenv("WANDB_MODE", "disabled")
-        monkeypatch.setenv("WANDB_API_KEY", "")
-        monkeypatch.setenv("JAX_TRACEBACK_FILTERING", "off")
-
-        iris_client = FrayIrisClient(
-            controller_address=integration_cluster.url,
-            workspace=REPO_ROOT,
-        )
-
-        config = ExecutorMainConfig(
-            prefix=prefix,
-            executor_info_base_path=os.path.join(prefix, "experiments"),
-        )
-
-        experiment_prefix = "quickstart-tests"
-        steps = create_steps(experiment_prefix, SYNTH_DATA)
-
-        with set_current_client(iris_client):
-            executor_main(config, steps=steps)
-    finally:
-        shutil.rmtree(prefix, ignore_errors=True)
-""" - -import logging -import os -import shutil -import tempfile -from pathlib import Path - -import pytest -from fray import set_current_client -from fray.v2.iris_backend import FrayIrisClient -from marin.execution.executor import ExecutorMainConfig, executor_main -from tests.integration_test import create_steps - -logger = logging.getLogger(__name__) - -REPO_ROOT = Path(__file__).resolve().parents[3] -SYNTH_DATA = str(REPO_ROOT / "tests" / "quickstart-data") - -pytestmark = [pytest.mark.integration, pytest.mark.slow] - - -@pytest.mark.timeout(600) -def test_marin_pipeline_on_iris(integration_cluster, monkeypatch): - """Run the full marin data pipeline dispatched through Iris.""" - prefix = tempfile.mkdtemp(prefix="iris-marin-itest-") - try: - monkeypatch.setenv("MARIN_PREFIX", prefix) - monkeypatch.setenv("WANDB_MODE", "disabled") - monkeypatch.setenv("WANDB_API_KEY", "") - monkeypatch.setenv("JAX_TRACEBACK_FILTERING", "off") - - iris_client = FrayIrisClient( - controller_address=integration_cluster.url, - workspace=REPO_ROOT, - ) - - config = ExecutorMainConfig( - prefix=prefix, - executor_info_base_path=os.path.join(prefix, "experiments"), - ) - - experiment_prefix = "quickstart-tests" - steps = create_steps(experiment_prefix, SYNTH_DATA) - - with set_current_client(iris_client): - executor_main(config, steps=steps) - finally: - shutil.rmtree(prefix, ignore_errors=True) diff --git a/tests/integration_test.py b/tests/integration_test.py index 2b4bb85118..f9445ea8df 100644 --- a/tests/integration_test.py +++ b/tests/integration_test.py @@ -20,14 +20,8 @@ ) from marin.execution.step_spec import StepSpec from marin.processing.classification.consolidate import FilterConfig, FilterType, consolidate, ConsolidateConfig -from marin.processing.classification.dataset_utils import DatasetConfig from marin.processing.classification.deduplication.exact import dedup_exact_paragraph from marin.processing.classification.deduplication.fuzzy import dedup_fuzzy_document -from marin.processing.classification.fasttext.train_fasttext import ( - TrainFasttextClassifierConfig, - train, -) -from marin.processing.classification.inference import InferenceConfig, run_inference from marin.processing.tokenize import lm_data_config from marin.processing.tokenize.tokenize import TokenizeConfig, tokenize from marin.schemas.web.convert import ResiliparseConfig @@ -153,64 +147,6 @@ def create_steps(prefix: str, synth_data: str) -> list[ExecutorStep]: transform_hq_data_step = transform_hq_data_spec.as_executor_step() transform_lq_data_step = transform_lq_data_spec.as_executor_step() - # ############################################################ - # Train quality classifier - - train_quality_step = ExecutorStep( - name=os.path.join(prefix, "quality-classifier"), - fn=train, - config=TrainFasttextClassifierConfig( - datasets=[ - DatasetConfig( - input_doc_path=transform_hq_data_step, - label="hq", - sampling_rate=1.0, - ), - DatasetConfig( - input_doc_path=transform_lq_data_step, - label="lq", - sampling_rate=1.0, - ), - ], - output_path=this_output_path(), - fasttext_args={ - "lr": 0.001, - "minCount": 1, - "epoch": 25, - "wordNgrams": 2, - "dim": 50, - "thread": 1, - }, - ), - ) - - ############################################################ - # Run inference with quality classifier - - inference_hq_step = ExecutorStep( - name=os.path.join(prefix, "hq-inference"), - fn=run_inference, - config=InferenceConfig( - input_path=transform_hq_data_step, - output_path=this_output_path(), - model_name=train_quality_step, - 
model_type="fasttext", - attribute_name="quickstart-fasttext-quality-hq", - ), - ) - - inference_lq_step = ExecutorStep( - name=os.path.join(prefix, "lq-inference"), - fn=run_inference, - config=InferenceConfig( - input_path=transform_lq_data_step, - output_path=this_output_path(), - model_name=train_quality_step, - model_type="fasttext", - attribute_name="quickstart-fasttext-quality-lq", - ), - ) - ############################################################ # Deduplicate (StepSpec — depends on transform StepSpecs) @@ -351,9 +287,6 @@ def create_steps(prefix: str, synth_data: str) -> list[ExecutorStep]: return [ transform_hq_data_step, transform_lq_data_step, - train_quality_step, - inference_hq_step, - inference_lq_step, dedup_exact_paragraph_step, dedup_fuzzy_document_step, validate_exact_dedup_step,