Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 97 additions & 0 deletions experiments/evals/gh_archive_structured_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Opt-in GH Archive structured-output PPL/gap eval wiring for issue #5098."""

from __future__ import annotations

import posixpath
from dataclasses import dataclass

from marin.datakit.download.gh_archive import (
GH_ARCHIVE_OPTIONAL_EVENT_TYPES,
GH_ARCHIVE_REQUIRED_EVENT_TYPES,
make_gh_archive_step,
)
from marin.evaluation.perplexity_gap import RawTextEvaluationDataset, raw_text_dataset
from marin.execution.executor import ExecutorStep

# Tracking identifiers; both are embedded in the tags attached to every slice.
EPIC_5005: int = 5005
GH_ARCHIVE_STRUCTURED_OUTPUT_ISSUE: int = 5098

# Held-out evaluation window. It is deliberately limited to a narrow hour range
# on a single day so the eval never triggers a broad GH Archive pull.
GH_ARCHIVE_EVAL_START_DATE: str = "2024-02-01"
GH_ARCHIVE_EVAL_END_DATE: str = "2024-02-01"
GH_ARCHIVE_EVAL_START_HOUR: int = 0
GH_ARCHIVE_EVAL_END_HOUR: int = 1

# Upper bound on events requested per event type for the eval download step.
GH_ARCHIVE_EVAL_MAX_EVENTS_PER_EVENT_TYPE: int = 512


@dataclass(frozen=True)
class GhArchiveStructuredOutputSlice:
    """One GH Archive event type exposed as a structured-output eval slice.

    Attributes:
        event_type: GH Archive event type name. It is the directory component
            of the raw-file glob and the final component of the registry key.
        optional: True for slices built from the optional event-type list;
            defaults to False.
    """

    event_type: str
    optional: bool = False

    @property
    def registry_key(self) -> str:
        """Return the dataset key, ``gh_archive_structured_output/<event_type>``."""
        return posixpath.join("gh_archive_structured_output", self.event_type)

    @property
    def raw_relative_glob(self) -> str:
        """Return the glob for this slice's raw files, relative to the raw root.

        POSIX separators are used regardless of the host platform.
        """
        return posixpath.join(self.event_type, "*.jsonl.gz")

    @property
    def tags(self) -> tuple[str, ...]:
        """Return the tags attached to this slice's dataset.

        The tuple holds the bundle name, the tracking epic and issue numbers,
        and the event type, in that order.
        """
        return (
            "gh_archive_structured_output",
            f"epic:{EPIC_5005}",
            f"issue:{GH_ARCHIVE_STRUCTURED_OUTPUT_ISSUE}",
            f"event_type:{self.event_type}",
        )


# Required event types come first, followed by the optional ones; the optional
# slices carry optional=True so they can be told apart later.
GH_ARCHIVE_STRUCTURED_OUTPUT_SLICES: tuple[GhArchiveStructuredOutputSlice, ...] = tuple(
    GhArchiveStructuredOutputSlice(event_type=required_type) for required_type in GH_ARCHIVE_REQUIRED_EVENT_TYPES
) + tuple(
    GhArchiveStructuredOutputSlice(event_type=optional_type, optional=True)
    for optional_type in GH_ARCHIVE_OPTIONAL_EVENT_TYPES
)

# Download step for the eval window. It requests every event type that has a
# slice above, capped per event type.
gh_archive_structured_output_eval = make_gh_archive_step(
    name="raw/gh_archive/structured_output_eval_2024_02_01_h00_01",
    start_date=GH_ARCHIVE_EVAL_START_DATE,
    end_date=GH_ARCHIVE_EVAL_END_DATE,
    start_hour=GH_ARCHIVE_EVAL_START_HOUR,
    end_hour=GH_ARCHIVE_EVAL_END_HOUR,
    event_types=tuple(archive_slice.event_type for archive_slice in GH_ARCHIVE_STRUCTURED_OUTPUT_SLICES),
    max_events_per_event_type=GH_ARCHIVE_EVAL_MAX_EVENTS_PER_EVENT_TYPE,
)


def gh_archive_structured_output_raw_validation_sets(
    *,
    raw_root: str | None = None,
    gh_archive_raw: ExecutorStep | None = None,
    include_optional_event_types: bool = True,
) -> dict[str, RawTextEvaluationDataset]:
    """Materialize GH Archive structured-output slices as opt-in raw validation datasets.

    Args:
        raw_root: Path prefix under which each slice's raw files live. Each
            slice's relative glob is joined onto it with POSIX separators.
        gh_archive_raw: Executor step whose output holds the raw files. When
            neither this nor ``raw_root`` is given, the module-level
            ``gh_archive_structured_output_eval`` step is used.
        include_optional_event_types: When False, slices marked optional are
            left out of the result.

    Returns:
        A mapping from each selected slice's registry key to the dataset built
        by ``raw_text_dataset`` for that slice's source and tags.

    Raises:
        ValueError: If both ``raw_root`` and ``gh_archive_raw`` are provided.
    """
    if raw_root is not None and gh_archive_raw is not None:
        raise ValueError("Provide either raw_root or gh_archive_raw, not both.")

    selected = [
        slice_
        for slice_ in GH_ARCHIVE_STRUCTURED_OUTPUT_SLICES
        if include_optional_event_types or not slice_.optional
    ]

    # Decide where the raw files come from once, up front, rather than
    # re-checking (and asserting) inside the per-slice loop.
    if raw_root is not None:
        sources = [posixpath.join(raw_root, slice_.raw_relative_glob) for slice_ in selected]
    else:
        step = gh_archive_raw if gh_archive_raw is not None else gh_archive_structured_output_eval
        sources = [step.cd(slice_.raw_relative_glob) for slice_ in selected]

    return {
        slice_.registry_key: raw_text_dataset(source, tags=slice_.tags)
        for slice_, source in zip(selected, sources)
    }
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

from fray.v2.types import ResourceConfig

from experiments.evals.gh_archive_structured_output import gh_archive_structured_output_raw_validation_sets
from marin.evaluation.perplexity_gap import GapFinderModelConfig, default_model_perplexity_gap
from marin.execution.executor import executor_main

# Hardware and per-dataset sampling caps shared by both comparison runs.
RESOURCE_CONFIG = ResourceConfig.with_tpu("v5p-8", regions=["us-central1"])
MAX_DOCS_PER_DATASET = 256
MAX_DOC_BYTES = 32_768

# Called with no arguments, so the datasets are built from the default
# GH Archive eval step, including the optional event types.
DATASETS = gh_archive_structured_output_raw_validation_sets()

MARIN_MODEL = GapFinderModelConfig(
    checkpoint_path="marin-community/marin-8b-base",
    checkpoint_is_hf=True,
    tokenizer="meta-llama/Llama-3.1-8B",
)


def _marin_perplexity_gap_step(baseline_slug: str, baseline_model: GapFinderModelConfig):
    """Build the perplexity-gap step comparing ``MARIN_MODEL`` against one baseline.

    Both comparisons share datasets, resources, batch size, length limits and
    W&B tags; only the baseline model and the run name differ. The document cap
    in the run name and in the tags is derived from ``MAX_DOCS_PER_DATASET`` so
    neither can drift from the value actually applied.

    Args:
        baseline_slug: Short baseline identifier inserted into the run name.
        baseline_model: Model configuration passed as ``model_b``.

    Returns:
        Whatever ``default_model_perplexity_gap`` returns for this comparison.
    """
    return default_model_perplexity_gap(
        name=f"gh-archive-structured-output-marin-8b-base-vs-{baseline_slug}-doccap{MAX_DOCS_PER_DATASET}",
        model_a=MARIN_MODEL,
        model_b=baseline_model,
        datasets=DATASETS,
        resource_config=RESOURCE_CONFIG,
        per_device_batch_size=4,
        max_eval_length=4096,
        max_docs_per_dataset=MAX_DOCS_PER_DATASET,
        max_doc_bytes=MAX_DOC_BYTES,
        # A fresh list per call, so the two steps never share a mutable tag list.
        wandb_tags=[
            "eval=perplexity-gap",
            "bundle=gh_archive_structured_output",
            "epic=5005",
            "issue=5098",
            f"max_docs_per_dataset={MAX_DOCS_PER_DATASET}",
        ],
    )


MARIN_VS_LLAMA = _marin_perplexity_gap_step(
    "llama-3.1-8b-base",
    GapFinderModelConfig(
        checkpoint_path="meta-llama/Llama-3.1-8B",
        checkpoint_is_hf=True,
        tokenizer="meta-llama/Llama-3.1-8B",
    ),
)

MARIN_VS_QWEN3 = _marin_perplexity_gap_step(
    "qwen3-8b-base",
    GapFinderModelConfig(
        checkpoint_path="Qwen/Qwen3-8B-Base",
        checkpoint_is_hf=True,
        tokenizer="Qwen/Qwen3-8B",
    ),
)


if __name__ == "__main__":
    # Hand both comparison steps to the executor in a single invocation.
    executor_main(
        [MARIN_VS_LLAMA, MARIN_VS_QWEN3],
        description="Run perplexity-gap reports on GH Archive structured-output eval slices.",
    )
Loading
Loading