diff --git a/benchmarks/arteval_bench/README.md b/benchmarks/arteval_bench/README.md index 9faf9a25..bfc30f77 100644 --- a/benchmarks/arteval_bench/README.md +++ b/benchmarks/arteval_bench/README.md @@ -34,68 +34,127 @@ Using WASABI's [agent evaluator](data/benchmark/sosp24_wasabi/wasabi/_agent_eval 1. An `_agent_eval/` package which contains all benchmark-specific code and does *not* modify your original artifact logic. -2. One oracle module per stage, implemented in four distinct Python files each checking one of the four canonical stages of artifact evaluation. A typical oracle module looks as follows (simplified): +2. One oracle module per stage. In this benchmark, each stage is typically implemented as a **derived oracle class** that overrides `requirements()` and returns an ordered list of programmatic checks (requirements). The base oracle handles running requirements, producing a structured report, printing a PASS/FAIL summary, and returning `True`/`False` from `run(verbose=...)`. + + A typical `_agent_eval/` layout looks like: + + ```text + _agent_eval/ + ├── main.py + ├── oracle_env_setup.py + ├── oracle_build_install.py + ├── oracle_prep_benchmark.py + ├── oracle_run_experiments.py + └── refs/ + ├── datasets.ref.json + └── results.ref.json + ``` + + The `refs/` directory stores machine-checkable ground truth (e.g., dataset manifests/checksums, expected metric tables, or summaries of deterministic outputs) used by benchmark-prep and experiment-runs checks. + + Here is a simplified environment setup oracle (one dependency/version requirement): + + ```python + # _agent_eval/oracle_env_setup.py + import sys + from collections.abc import Sequence + + from evaluator.oracle_env_setup_primitives import ( + DependencyVersionRequirement, + OracleEnvSetupBase, + VersionCompare, + ) + + class OracleEnvSetup(OracleEnvSetupBase): + def __init__(self, *, config, logger): + super().__init__(logger=logger) + self._config = config + + def requirements(self) -> Sequence[DependencyVersionRequirement]: + return ( + DependencyVersionRequirement( + name="python_version", + cmd=(sys.executable, "--version"), + required_version=(3, 10, 0), + compare=VersionCompare.GEQ, + timeout_seconds=5.0, + ), + ) + ``` + + Also, note that each oracle should be: + - Non-interactive, meaning not expecting input or prompt interactions. + - Idempotent, meaning safe to run multiple times without side-effects. + - Time-bounded, meaning every command has a timeout. + - Binary, meaning it returns pass/fail (as `True`/`False`) for the stage. + + For more details, check out this [how-to guide](src/evaluator/HOWTO.md) + +1. A single `main.py` orchestrator, the entrypoint used by ArtEvalBench, which constructs an `EntryConfig`, invokes the four oracles in order, and returns an overall score (an integer between 0 and 4): + ```python - # _agent_eval/env_setup.py - import subprocess + # _agent_eval/main.py + import os from pathlib import Path - def check() -> bool: - # Example: verify virtualenv exists - if not Path("venv").exists(): - print("Missing venv/ directory") - return False - - # Example: verify Python version inside the venv - proc = subprocess.run( - ["venv/bin/python", "--version"], - capture_output=True, - text=True, + from evaluator.utils import EntryConfig, LoggerConfig, get_logger, record_result + + from oracle_env_setup import OracleEnvSetup + from oracle_build_install import OracleBuildInstall + from oracle_prep_benchmark import OraclePrepBenchmark + from oracle_run_experiments import OracleRunExperiments + + CONFIG = EntryConfig( + name="my-artifact", + home_dir=Path.home() / "artevalbench" / "my-artifact", + repository_paths={ + "my-artifact": Path.home() / "artevalbench" / "my-artifact" / "repo", + }, + results_paths={ + "results": Path.home() / "artevalbench" / "my-artifact" / "repo" / "outputs" / "results.json", + }, + ground_truth_paths={ + "datasets": Path.home() / "artevalbench" / "my-artifact" / "_agent_eval" / "refs" / "datasets.ref.json", + "results": Path.home() / "artevalbench" / "my-artifact" / "_agent_eval" / "refs" / "results.ref.json", + }, + similarity_ratio=0.75, + ) + + def main(argv: list[str]) -> int: + verbose = "--verbose" in argv + logger = get_logger( + LoggerConfig(root_name=os.environ.get("EVAL_LOGGER_NAME", "ARTEVAL-EVAL")) + ) + + results: dict[str, int] = {} + score = 0 + + score += record_result( + results, "env_setup", + OracleEnvSetup(config=CONFIG, logger=logger).run(verbose=verbose), + ) + score += record_result( + results, "build_install", + OracleBuildInstall(config=CONFIG, logger=logger).run(verbose=verbose), + ) + score += record_result( + results, "prep_benchmark", + OraclePrepBenchmark(config=CONFIG, logger=logger).run(verbose=verbose), ) - print(proc.stdout.strip()) - return proc.returncode == 0 and proc.stdout.startswith("Python 3.10") - ``` - Also, note that each oracle should be: - - Non-interactive, meaning not expecting input or prompt interactions. - - Idempotent, meaning safe to run multiple times without side-effects. - - It returns `True` or `False` based on the validation outcome and prints a brief diagnostic message. - -3. A single `main.py` orchestrator, the entrypoint used by ArtEvalBench, which invokes the four oracle modules, runs them in order, and returns an overall score (an integer between 0 and 4): - ```python - # _agent_eval/main.py - from . import env_setup, build_install, prep_benchmark, run_experiments - - def main() -> int: - score = 0 - stages = [ - ("env_setup", env_setup.check), - ("build_install", build_install.check), - ("prep_benchmark", prep_benchmark.check), - ("run_experiments", run_experiments.check), - ] - - for name, check in stages: - try: - ok = bool(check()) - except Exception as e: - print(f"[{name}] FAILED with exception: {e}") - ok = False - - if ok: - print(f"[{name}] PASSED") - score += 1 - else: - print(f"[{name}] FAILED") - - print(f"FINAL_SCORE {score}/4") - return score - - if __name__ == "__main__": - raise SystemExit(main()) - ``` - - Note that the `ArtEvalBench` framework will invoke `main.py` to run the oracles in order, compute the agent's score for this particular artifact, and store it into a JSON file that aggregates these outcomes for the entire benchmark. + score += record_result( + results, "run_experiments", + OracleRunExperiments(config=CONFIG, logger=logger).run(verbose=verbose), + ) + + logger.info("Stage scores: %s", results) + logger.info("FINAL_SCORE %d/4", score) + return score + if __name__ == "__main__": + raise SystemExit(main([])) + ``` + + Note that the `ArtEvalBench` framework will invoke `main.py` to run the oracles in order, compute the agent's score for this particular artifact, and store it into a JSON file that aggregates these outcomes for the entire benchmark. ## Benchmark Setup @@ -105,10 +164,10 @@ To run the benchmark: 1. Execute the `run.sh` script with your model: - ```sh - ./run.sh - # Example: ./run.sh claude-sonnet-4-5-20250929 - ``` +```sh +./run.sh +# Example: ./run.sh claude-sonnet-4-5-20250929 +``` 2. Configure your LLM endpoint in `env.toml`: * For Azure/OpenAI models: Set `AZURE_API_KEY`, `AZURE_API_BASE`, `AZURE_API_VERSION` @@ -117,7 +176,6 @@ To run the benchmark: 3. Results will be saved to `outputs/` with timestamp and model information - #### » Supported Agents The benchmark supports multiple AI agents: diff --git a/benchmarks/arteval_bench/src/evaluator/README.md b/benchmarks/arteval_bench/src/evaluator/HOWTO.md similarity index 94% rename from benchmarks/arteval_bench/src/evaluator/README.md rename to benchmarks/arteval_bench/src/evaluator/HOWTO.md index d887df22..8aa160ae 100644 --- a/benchmarks/arteval_bench/src/evaluator/README.md +++ b/benchmarks/arteval_bench/src/evaluator/HOWTO.md @@ -1,15 +1,15 @@ # Agent Evaluator Primitives -This bundle provides primitives for four oracles that verify if an AI agent can succesfully evaluating a set of artifacts, namely setting up, building code, downloading datasets and runing experiments end-to-end. Each oracle corresponds to one stage of the artifact evaluation (AE) process and encodes minimal, objective, and programatically verifiable success criteria. Oracles are designed to be idempotent (safe to run multiple times), non-interactive (no blocking events like I/O actions or manual intervention), and produce a binary outcome (either "pass" or "fail"). +This utility provides building blocks for four validation oracles that check whether an AI agent can evaluate artifacts end-to-end: set up the environment, build or install the main code modules, download and prepare datasets/benchmarks, and run the experiments. Each oracle matches one cannonical stage of the artifact evaluation (AE) process and defines simple, objective checks that can be verified programatically. The oracles are idempotent (i.e., safe to run multiple times), non-interactive (i.e., no prompts or manual steps), and return a clear result (i.e. "pass" or "fail"). -The oracles verify four canonical stages of the AE process: +The four canonical stages of the AE process that these oracles validate are as follows: 1. Environment setup: check required tools/dependencies exist and meet version constraints; confirm key environment variables and required files/directories are present. 2. Artifact build: run build/install commands and fail if they do not complete successfully. 3. Benchmark preparation: check datasets/benchmarks/tools are present and usable; optionally run quick commands and check for expected output signatures. 4. Experiment runs: compare observed to reference values using similarity or elementwise checks within cutomizable tolerance thresholds. -Each artifact includes a self-contained oracles in a `_agent_eval/` directory. These scripts extend the base primitives descrived above to create specialized oracles that assert success criteria at each AE stage. +Each artifact includes a self-contained oracles in a `_agent_eval/` directory. This extra code extends the base primitives described above to create specialized oracles that assert success criteria at each AE stage. ## Implementing agent evaluators diff --git a/benchmarks/arteval_bench/src/evaluator/oracle_artifact_build_primitives.py b/benchmarks/arteval_bench/src/evaluator/oracle_artifact_build_primitives.py index 2f0f4a1d..0fb17ee2 100644 --- a/benchmarks/arteval_bench/src/evaluator/oracle_artifact_build_primitives.py +++ b/benchmarks/arteval_bench/src/evaluator/oracle_artifact_build_primitives.py @@ -26,7 +26,6 @@ from evaluator import utils - # ------------------------------------------------------------------------------ # Helper functions # ------------------------------------------------------------------------------ @@ -69,7 +68,6 @@ class BuildContext: logger: logging.Logger - @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class BuildCommandRequirement(utils.BaseRequirement): """Runs a build command within a working directory. @@ -78,14 +76,14 @@ class BuildCommandRequirement(utils.BaseRequirement): name: Human-readable requirement name for logs and reports. optional: Whether failure should be treated as a warning instead of an error. cwd: Base working directory. - command: Command argv to execute. + cmd: Command argv to execute. relative_workdir: Optional subdirectory within cwd used as the actual workdir. timeout_seconds: Timeout for the command, in seconds. env_overrides: Environment variables to override for the subprocess. """ cwd: pathlib.Path - command: Sequence[str] + cmd: Sequence[str] relative_workdir: pathlib.Path | None = None timeout_seconds: float = 60.0 env_overrides: Mapping[str, str] = dataclasses.field(default_factory=dict) @@ -93,36 +91,41 @@ class BuildCommandRequirement(utils.BaseRequirement): def __post_init__(self) -> None: object.__setattr__(self, "cwd", utils.to_path(self.cwd)) if self.relative_workdir is not None: - object.__setattr__(self, "relative_workdir", utils.to_path(self.relative_workdir)) + object.__setattr__(self, "relative_workdir", + utils.to_path(self.relative_workdir)) - if isinstance(self.command, (str, bytes)): - raise TypeError(f"{self.name}: command must be a sequence of argv strings, not a single string/bytes") + if isinstance(self.cmd, (str, bytes)): + raise TypeError( + f"{self.name}: command must be a sequence of argv strings, not a single string/bytes" + ) - if not self.command: + if not self.cmd: raise ValueError(f"{self.name}: command must be non-empty") - bad = [a for a in self.command if not isinstance(a, str) or a == ""] + bad = [a for a in self.cmd if not isinstance(a, str) or a == ""] if bad: - raise TypeError(f"{self.name}: all command argv entries must be non-empty str; bad entries: {bad!r}") + raise TypeError( + f"{self.name}: all command argv entries must be non-empty str; bad entries: {bad!r}" + ) if self.timeout_seconds <= 0: raise ValueError(f"{self.name}: timeout (seconds) must be > 0") - # NOTE: Be tolerant to callers passing non-str values (e.g., Path/int) by - # normalizing everything to str, since subprocess env requires str->str. env_dict_raw = dict(self.env_overrides) env_dict: dict[str, str] = {} for k, v in env_dict_raw.items(): - # Preserve previous strictness for obviously broken keys. if k is None or k == "": - raise TypeError(f"{self.name}: env_overrides contains an empty env var name: {k!r}") + raise TypeError( + f"{self.name}: env_overrides contains an empty env var name: {k!r}") env_dict[str(k)] = str(v) - # Prevent obvious "not relative" cases early. - if self.relative_workdir is not None and self.relative_workdir.is_absolute(): - raise ValueError(f"{self.name}: relative_workdir must be a relative path, got: {self.relative_workdir}") + if self.relative_workdir is not None and self.relative_workdir.is_absolute( + ): + raise ValueError( + f"{self.name}: relative_workdir must be a relative path, got: {self.relative_workdir}" + ) - object.__setattr__(self, "command", tuple(self.command)) + object.__setattr__(self, "command", tuple(self.cmd)) object.__setattr__(self, "env_overrides", types.MappingProxyType(env_dict)) @staticmethod @@ -134,10 +137,6 @@ def _is_within_base_dir(*, base: pathlib.Path, target: pathlib.Path) -> bool: try: base_real = base.resolve(strict=True) target_real = target.resolve(strict=True) - - # NOTE: Prefer pathlib semantics over string commonpath to avoid - # platform corner cases (drives, separators). This also avoids false - # positives from simple string-prefix checks. try: target_real.relative_to(base_real) return True @@ -148,41 +147,37 @@ def _is_within_base_dir(*, base: pathlib.Path, target: pathlib.Path) -> bool: @staticmethod def _coerce_text(x: object) -> str: - # NOTE: utils.decode_text may not accept str in some codebases. This helper - # safely handles bytes/str/None and keeps the old behavior stable. if x is None: return "" if isinstance(x, str): return x if isinstance(x, (bytes, bytearray, memoryview)): return utils.decode_text(bytes(x)) - # Fallback: best-effort stringification return str(x) def _run_with_limited_output( - self, - *, - workdir: pathlib.Path, - env: Mapping[str, str], + self, + *, + workdir: pathlib.Path, + env: Mapping[str, str], ) -> tuple[int | None, str, str, bool]: """Run process while limiting captured output to avoid unbounded memory. Returns (returncode, stdout, stderr, timed_out). """ - # NOTE: We run with stdout/stderr pipes in *binary* mode and decode ourselves. - # This avoids UnicodeDecodeError surprises while reading incrementally. try: proc = subprocess.Popen( - self.command, - cwd=workdir, - env=dict(env), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=False, + self.cmd, + cwd=workdir, + env=dict(env), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=False, ) except OSError as exc: - # Let caller map this to CheckResult.failure, preserving existing behavior. - raise + cmd_display = " ".join(self.cmd) + raise OSError( + f"failed to run command: {cmd_display} (cwd={workdir})") from exc assert proc.stdout is not None assert proc.stderr is not None @@ -191,8 +186,6 @@ def _run_with_limited_output( sel.register(proc.stdout, selectors.EVENT_READ, data="stdout") sel.register(proc.stderr, selectors.EVENT_READ, data="stderr") - # NOTE: Cap memory usage by storing only up to a fixed number of bytes. - # We use 4x char cap as a conservative UTF-8 upper bound. byte_cap = int(utils.DEFAULT_MAX_CAPTURE_CHARS) * 4 stdout_buf = bytearray() @@ -202,12 +195,10 @@ def _run_with_limited_output( timed_out = False def _read_chunk(stream) -> bytes: - # Prefer read1 when available for buffered streams. if hasattr(stream, "read1"): - return stream.read1(8192) # type: ignore[attr-defined] + return stream.read1(8192) return stream.read(8192) - # Read incrementally from both pipes until closed or timeout. while sel.get_map(): remaining = deadline - time.monotonic() if remaining <= 0: @@ -233,7 +224,6 @@ def _read_chunk(stream) -> bytes: if len(stdout_buf) < byte_cap: take = min(len(chunk), byte_cap - len(stdout_buf)) stdout_buf.extend(chunk[:take]) - # NOTE: Discard remainder to cap memory; continue draining to avoid deadlock. else: if len(stderr_buf) < byte_cap: take = min(len(chunk), byte_cap - len(stderr_buf)) @@ -245,8 +235,6 @@ def _read_chunk(stream) -> bytes: except Exception: pass - # Best-effort drain for a short period so we capture some tail output - # without risking hangs. drain_deadline = time.monotonic() + 1.0 while sel.get_map() and time.monotonic() < drain_deadline: events = sel.select(timeout=0.1) @@ -272,25 +260,26 @@ def _read_chunk(stream) -> bytes: take = min(len(chunk), byte_cap - len(stderr_buf)) stderr_buf.extend(chunk[:take]) - # Reap the process to avoid zombies. try: proc.wait(timeout=5.0) except Exception: pass - stdout = utils.truncate_text(self._coerce_text(stdout_buf), utils.DEFAULT_MAX_CAPTURE_CHARS) - stderr = utils.truncate_text(self._coerce_text(stderr_buf), utils.DEFAULT_MAX_CAPTURE_CHARS) + stdout = utils.truncate_text(self._coerce_text(stdout_buf), + utils.DEFAULT_MAX_CAPTURE_CHARS) + stderr = utils.truncate_text(self._coerce_text(stderr_buf), + utils.DEFAULT_MAX_CAPTURE_CHARS) return None, stdout, stderr, True - # Process finished or pipes closed; reap returncode. try: rc = proc.wait(timeout=5.0) except Exception: - # If something odd happens, keep behavior conservative. rc = proc.returncode - stdout = utils.truncate_text(self._coerce_text(stdout_buf), utils.DEFAULT_MAX_CAPTURE_CHARS) - stderr = utils.truncate_text(self._coerce_text(stderr_buf), utils.DEFAULT_MAX_CAPTURE_CHARS) + stdout = utils.truncate_text(self._coerce_text(stdout_buf), + utils.DEFAULT_MAX_CAPTURE_CHARS) + stderr = utils.truncate_text(self._coerce_text(stderr_buf), + utils.DEFAULT_MAX_CAPTURE_CHARS) return rc, stdout, stderr, False def check(self, ctx: BuildContext) -> utils.CheckResult: @@ -307,11 +296,11 @@ def check(self, ctx: BuildContext) -> utils.CheckResult: if error is not None: return utils.CheckResult.failure(error, cwd=workdir) - # Walidate cwd and prevent ``espacping'' (e.g., ../ or symlinks) + # Validate cwd and prevent ``espacping'' (e.g., ../ or symlinks) if not self._is_within_base_dir(base=self.cwd, target=workdir): return utils.CheckResult.failure( - f"working directory escapes base cwd: base={self.cwd} workdir={workdir}", - cwd=workdir, + f"working directory escapes base cwd: base={self.cwd} workdir={workdir}", + cwd=workdir, ) env = os.environ.copy() @@ -319,31 +308,29 @@ def check(self, ctx: BuildContext) -> utils.CheckResult: env.update(self.env_overrides) try: - # NOTE: Avoid capture_output=True because it can buffer unbounded output - # and spike memory; we capture incrementally with a fixed cap. returncode, stdout, stderr, timed_out = self._run_with_limited_output( - workdir=workdir, - env=env, + workdir=workdir, + env=env, ) except OSError as exc: return utils.CheckResult.failure( - f"failed to run command: {exc}", - stdout="", - stderr=str(exc), - returncode=None, - timed_out=False, - cwd=workdir, + f"failed to run command: {exc}", + stdout="", + stderr=str(exc), + returncode=None, + timed_out=False, + cwd=workdir, ) if timed_out: # Handle case when stdout/stderr is None return utils.CheckResult.failure( - f"command timed out after {self.timeout_seconds}s", - stdout=stdout, - stderr=stderr, - returncode=None, - timed_out=True, - cwd=workdir, + f"command timed out after {self.timeout_seconds}s", + stdout=stdout, + stderr=stderr, + returncode=None, + timed_out=True, + cwd=workdir, ) if returncode != 0: @@ -352,19 +339,19 @@ def check(self, ctx: BuildContext) -> utils.CheckResult: if detail: msg = f"{msg}: {detail}" return utils.CheckResult.failure( - msg, + msg, + stdout=stdout, + stderr=stderr, + returncode=returncode, + timed_out=False, + cwd=workdir, + ) + + return utils.CheckResult.success( stdout=stdout, stderr=stderr, returncode=returncode, - timed_out=False, cwd=workdir, - ) - - return utils.CheckResult.success( - stdout=stdout, - stderr=stderr, - returncode=returncode, - cwd=workdir, ) diff --git a/benchmarks/arteval_bench/src/evaluator/oracle_benchmark_prep_primitives.py b/benchmarks/arteval_bench/src/evaluator/oracle_benchmark_prep_primitives.py index b3b7335a..b079920b 100644 --- a/benchmarks/arteval_bench/src/evaluator/oracle_benchmark_prep_primitives.py +++ b/benchmarks/arteval_bench/src/evaluator/oracle_benchmark_prep_primitives.py @@ -30,15 +30,12 @@ from evaluator import utils - # ------------------------------------------------------------------------------ # Basic types and constants # ------------------------------------------------------------------------------ - _CommandT = str | Sequence[str] - # ------------------------------------------------------------------------------ # Helper functions # ------------------------------------------------------------------------------ @@ -48,7 +45,6 @@ def _format_command(cmd: _CommandT, *, use_shell: bool) -> str: """Returns a readable representation of command suitable for error messages.""" if isinstance(cmd, str): return cmd if use_shell else shlex.quote(cmd) - # NOTE: quote() used for readability display only return " ".join(shlex.quote(str(arg)) for arg in cmd) @@ -108,7 +104,8 @@ def _run_command( max_chars = utils.DEFAULT_MAX_CAPTURE_CHARS suffix = "..." - def _append_bounded(buf: list[str], cur_len: int, text: str) -> tuple[int, bool]: + def _append_bounded(buf: list[str], cur_len: int, + text: str) -> tuple[int, bool]: """Append up to max_chars, return (new_len, overflowed).""" if cur_len >= max_chars: return cur_len, True @@ -126,7 +123,7 @@ def _append_bounded(buf: list[str], cur_len: int, text: str) -> tuple[int, bool] stdout_tail = "" stderr_tail = "" - stderr_head = "" + stderr_head = "" encoding = locale.getpreferredencoding(False) or "utf-8" stdout_dec = codecs.getincrementaldecoder(encoding)(errors="replace") diff --git a/benchmarks/arteval_bench/src/evaluator/oracle_env_setup_primitives.py b/benchmarks/arteval_bench/src/evaluator/oracle_env_setup_primitives.py index 07a2dc62..aa0dea3d 100644 --- a/benchmarks/arteval_bench/src/evaluator/oracle_env_setup_primitives.py +++ b/benchmarks/arteval_bench/src/evaluator/oracle_env_setup_primitives.py @@ -26,14 +26,13 @@ from evaluator import utils - # ------------------------------------------------------------------------------ # Basic types and constants # ------------------------------------------------------------------------------ - SemanticVersion = tuple[int, int, int] + @enum.unique class VersionCompare(enum.Enum): """Comparison operator for validating a discovered version.""" @@ -102,7 +101,7 @@ class DependencyVersionRequirement(utils.BaseRequirement): Attributes: name: Human-readable requirement name for logs and reports. optional: Whether failure should be treated as a warning instead of an error. - command: Command argv used to query a version (e.g., ["python", "--version"]). + cmd: Command argv used to query a version (e.g., ["python", "--version"]). required_version: Minimum/required semantic version as (major, minor, patch). compare: Comparison operator to apply against required_version. version_regex: Optional regex with a capturing group for the version token. @@ -124,7 +123,7 @@ def __post_init__(self) -> None: raise ValueError(f"{self.name}: command must be non-empty") if self.timeout_seconds <= 0: raise ValueError(f"{self.name}: timeout_seconds must be > 0") - object.__setattr__(self, "command", tuple(self.cmd)) + object.__setattr__(self, "cmd", tuple(self.cmd)) if self.version_regex is not None: pattern = re.compile(self.version_regex, flags=re.IGNORECASE) @@ -326,7 +325,7 @@ def check(self) -> utils.CheckResult: if self.path.is_dir(): return utils.CheckResult.success() return utils.CheckResult.failure(f"expected directory: {self.path}") - + class OracleEnvSetupBase(abc.ABC): """Base class for an environment setup oracle. diff --git a/benchmarks/arteval_bench/src/evaluator/oracle_experiment_runs_primitives.py b/benchmarks/arteval_bench/src/evaluator/oracle_experiment_runs_primitives.py index 16b49534..fcea5397 100644 --- a/benchmarks/arteval_bench/src/evaluator/oracle_experiment_runs_primitives.py +++ b/benchmarks/arteval_bench/src/evaluator/oracle_experiment_runs_primitives.py @@ -22,18 +22,21 @@ import math import typing +from collections import Counter from collections.abc import Callable, Sequence from evaluator import utils - # --------------------------------------------------------------------------- # Basic types and constants # --------------------------------------------------------------------------- - _CmpT = typing.TypeVar("_CmpT") +# Numerical tolerance for treating very small norms/variances as zero. +_EPS = 1e-12 + + @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class Compared(typing.Generic[_CmpT]): """A single observed-vs-reference comparison record. @@ -79,22 +82,20 @@ def _require_all_finite(values: Sequence[float], *, label: str) -> None: def _jaccard_set_similarity(left: Sequence[float], right: Sequence[float]) -> float: """Jaccard similarity treating inputs as sets (order/duplicates ignored).""" + _require_all_finite(left, label="jaccard_set_similarity.left") + _require_all_finite(right, label="jaccard_set_similarity.right") - def _normalize(x: float) -> object: - if _is_nan(x): - return ("nan",) - return x - - left_norm = [_normalize(x) for x in left] - right_norm = [_normalize(x) for x in right] - - a = set(left_norm) - b = set(right_norm) + a = set(left) + b = set(right) - if len(a) != len(left_norm): - raise ValueError("jaccard_set_similarity: left input contains duplicates (multiset not allowed)") - if len(b) != len(right_norm): - raise ValueError("jaccard_set_similarity: right input contains duplicates (multiset not allowed)") + if len(a) != len(left): + raise ValueError( + "jaccard_set_similarity: left input contains duplicates; " + "use jaccard_multiset_similarity if duplicates are meaningful") + if len(b) != len(right): + raise ValueError( + "jaccard_set_similarity: right input contains duplicates; " + "use jaccard_multiset_similarity if duplicates are meaningful") union = a | b if not union: @@ -102,6 +103,24 @@ def _normalize(x: float) -> object: return len(a & b) / len(union) +def _jaccard_multiset_similarity(left: Sequence[float], + right: Sequence[float]) -> float: + """Jaccard similarity treating inputs as multisets (duplicates preserved).""" + _require_all_finite(left, label="jaccard_multiset_similarity.left") + _require_all_finite(right, label="jaccard_multiset_similarity.right") + + a = Counter(left) + b = Counter(right) + keys = set(a) | set(b) + + den = sum(max(a[k], b[k]) for k in keys) + if den == 0: + return 1.0 + + num = sum(min(a[k], b[k]) for k in keys) + return num / den + + def _dot_product(left: Sequence[float], right: Sequence[float]) -> float: """Dot product (unbounded). Requires equal lengths and finite inputs.""" _require_equal_lengths(left, right, label="dot_product") @@ -129,9 +148,9 @@ def _cosine_similarity(left: Sequence[float], right: Sequence[float]) -> float: norm_left += a * a norm_right += b * b - if norm_left == 0.0 and norm_right == 0.0: + if norm_left <= _EPS and norm_right <= _EPS: return 1.0 - if norm_left == 0.0 or norm_right == 0.0: + if norm_left <= _EPS or norm_right <= _EPS: return 0.0 return dot / (math.sqrt(norm_left) * math.sqrt(norm_right)) @@ -169,9 +188,9 @@ def _pearson_similarity(left: Sequence[float], right: Sequence[float]) -> float: var_left += da * da var_right += db * db - if var_left == 0.0 and var_right == 0.0: + if var_left <= _EPS and var_right <= _EPS: return 1.0 if list(left) == list(right) else 0.0 - if var_left == 0.0 or var_right == 0.0: + if var_left <= _EPS or var_right <= _EPS: return 0.0 return cov / (math.sqrt(var_left) * math.sqrt(var_right)) @@ -210,8 +229,9 @@ def _min_max_similarity(left: Sequence[float], right: Sequence[float]) -> float: def _numbers_equal(a: float, b: float, *, nan_equal: bool) -> bool: - if nan_equal and _is_nan(a) and _is_nan(b): - return True + del nan_equal # Non-finite inputs are rejected by policy. + if not math.isfinite(a) or not math.isfinite(b): + raise ValueError(f"numbers_equal: non-finite input: a={a!r}, b={b!r}") return a == b @@ -225,11 +245,10 @@ def _default_numeric_similarity(a: float, b: float, *, - NaN vs NaN => 1.0, NaN vs non-NaN => 0.0 - +inf vs +inf or -inf vs -inf => 1.0, otherwise 0.0 """ - if _is_nan(a) or _is_nan(b): - return 1.0 if (_is_nan(a) and _is_nan(b)) else 0.0 - - if math.isinf(a) or math.isinf(b): - return 1.0 if a == b else 0.0 + # Uniform policy: reject non-finite inputs across the entire file. + if not math.isfinite(a) or not math.isfinite(b): + raise ValueError( + f"default_numeric_similarity: non-finite input: a={a!r}, b={b!r}") denom = max(abs(a), abs(b), abs_epsilon) score = 1.0 - (abs(a - b) / denom) @@ -241,7 +260,6 @@ def _default_numeric_similarity(a: float, b: float, *, return score - def _elementwise_similarity_scores( observed: Sequence[float], reference: Sequence[float], @@ -252,6 +270,9 @@ def _elementwise_similarity_scores( _require_equal_lengths(observed, reference, label="elementwise_similarity_scores") + _require_all_finite(observed, label="elementwise_similarity_scores.observed") + _require_all_finite(reference, + label="elementwise_similarity_scores.reference") if abs_epsilon <= 0: raise ValueError(f"elementwise_similarity_scores: abs_epsilon must be > 0") @@ -273,6 +294,8 @@ def _elementwise_equal( nan_equal: bool, ) -> list[Compared[bool]]: _require_equal_lengths(observed, reference, label="elementwise_equal") + _require_all_finite(observed, label="elementwise_equal.observed") + _require_all_finite(reference, label="elementwise_equal.reference") out: list[Compared[bool]] = [] for a, b in zip(observed, reference, strict=True): out.append( @@ -334,6 +357,7 @@ class SimilarityMetric(enum.Enum): """List-level metric identifier for computing a single similarity score.""" JACCARD_SET = "jaccard_set" + JACCARD_MULTISET = "jaccard_multiset" DOT_PRODUCT = "dot_product" COSINE = "cosine" PEARSON = "pearson" @@ -351,6 +375,8 @@ def compute( ) -> float: if metric == SimilarityMetric.JACCARD_SET: return _jaccard_set_similarity(left, right) + if metric == SimilarityMetric.JACCARD_MULTISET: + return _jaccard_multiset_similarity(left, right) if metric == SimilarityMetric.DOT_PRODUCT: return _dot_product(left, right) if metric == SimilarityMetric.COSINE: @@ -407,6 +433,10 @@ def similarity_threshold( if not math.isfinite(threshold): raise ValueError( f"similarity_threshold: threshold must be finite, got {threshold!r}") + if not (0.0 <= threshold <= 1.0): + raise ValueError( + f"similarity_threshold: threshold must be in [0, 1], got {threshold!r}" + ) out: list[Compared[bool]] = [] for s in scores: @@ -417,7 +447,7 @@ def similarity_threshold( return out -@dataclasses.dataclass(...) +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class ExperimentRunsContext: """Context passed to experiment-run requirements. @@ -449,6 +479,20 @@ class ListSimilarityRequirement(utils.BaseRequirement): def __post_init__(self) -> None: if not math.isfinite(self.min_similarity): raise ValueError(f"{self.name}: min_similarity must be finite") + + if self.metric in (SimilarityMetric.JACCARD_SET, + SimilarityMetric.JACCARD_MULTISET, + SimilarityMetric.MIN_MAX): + if not (0.0 <= self.min_similarity <= 1.0): + raise ValueError( + f"{self.name}: {self.metric.value} min_similarity must be in [0, 1], " + f"got {self.min_similarity!r}") + if self.metric in (SimilarityMetric.COSINE, SimilarityMetric.PEARSON): + if not (-1.0 <= self.min_similarity <= 1.0): + raise ValueError( + f"{self.name}: {self.metric.value} min_similarity must be in [-1, 1], " + f"got {self.min_similarity!r}") + object.__setattr__(self, "observed", tuple(self.observed)) object.__setattr__(self, "reference", tuple(self.reference)) @@ -458,7 +502,7 @@ def check(self, ctx: ExperimentRunsContext) -> utils.CheckResult: score = SimilarityMetrics.compute(self.metric, self.observed, self.reference) except ValueError as exc: - return utils.CheckResult.failure(str(exc)) + return utils.CheckResult.failure(f"{self.name}: {exc}") if score < self.min_similarity: return utils.CheckResult.failure( @@ -498,7 +542,7 @@ def check(self, ctx: ExperimentRunsContext) -> utils.CheckResult: self.reference, nan_equal=self.nan_equal) except ValueError as exc: - return utils.CheckResult.failure(str(exc)) + return utils.CheckResult.failure(f"{self.name}: {exc}") if all(c.result for c in comps): return utils.CheckResult.success() @@ -534,6 +578,9 @@ class ElementwiseSimilarityThresholdRequirement(utils.BaseRequirement): def __post_init__(self) -> None: if not math.isfinite(self.threshold): raise ValueError(f"{self.name}: threshold must be finite") + if not (0.0 <= self.threshold <= 1.0): + raise ValueError( + f"{self.name}: threshold must be in [0, 1], got {self.threshold!r}") if self.abs_epsilon <= 0: raise ValueError(f"{self.name}: abs_epsilon must be > 0") if self.max_mismatches_to_report <= 0: @@ -550,7 +597,7 @@ def check(self, ctx: ExperimentRunsContext) -> utils.CheckResult: abs_epsilon=self.abs_epsilon, ) except ValueError as exc: - return utils.CheckResult.failure(str(exc)) + return utils.CheckResult.failure(f"{self.name}: {exc}") if all(s.result >= self.threshold for s in scores): return utils.CheckResult.success() diff --git a/benchmarks/arteval_bench/src/evaluator/utils.py b/benchmarks/arteval_bench/src/evaluator/utils.py index 18afb499..c83de7cb 100644 --- a/benchmarks/arteval_bench/src/evaluator/utils.py +++ b/benchmarks/arteval_bench/src/evaluator/utils.py @@ -6,6 +6,7 @@ from __future__ import annotations +import abc import dataclasses import logging import os @@ -27,10 +28,9 @@ Version = typing.Tuple[int, int, int] - # ------------------------------------------------------------------------------ # Shared config helpers -# ---------------------------- +# ------------------------------------------------------------------------------ @dataclasses.dataclass(frozen=True) @@ -49,15 +49,18 @@ class EntryConfig: name: str home_dir: pathlib.Path - repository_paths: typing.Dict[str, pathlib.Path] = typing.field( + repository_paths: typing.Dict[str, pathlib.Path] = dataclasses.field( + default_factory=dict) + results_paths: typing.Dict[str, pathlib.Path] = dataclasses.field( default_factory=dict) - results_paths: typing.Dict[str, - pathlib.Path] = typing.field(default_factory=dict) - ground_truth_paths: typing.Dict[str, pathlib.Path] = typing.field( + ground_truth_paths: typing.Dict[str, pathlib.Path] = dataclasses.field( default_factory=dict) similarity_ratio: float = 0.75 + metadata: typing.Dict[str, + typing.Any] = dataclasses.field(default_factory=dict) + @dataclasses.dataclass(frozen=True, slots=True) class CheckResult: @@ -172,9 +175,9 @@ def check(self, ctx: BenchmarkContext) -> utils.CheckResult: raise NotImplementedError -# ---------------------------- +# ------------------------------------------------------------------------------ # Logging helpers -# ---------------------------- +# ------------------------------------------------------------------------------ @dataclasses.dataclass(frozen=True, slots=True) @@ -222,11 +225,9 @@ def log_result_details(logger: logging.Logger, result: CheckResult) -> None: def _is_console_handler(h: logging.Handler) -> bool: """Checks if a logging handler targets the standard console output.""" - return ( - isinstance(h, logging.StreamHandler) - and not isinstance(h, logging.FileHandler) - and getattr(h, "stream", None) in (sys.stdout, sys.stderr) - ) + return (isinstance(h, logging.StreamHandler) and + not isinstance(h, logging.FileHandler) and + getattr(h, "stream", None) in (sys.stdout, sys.stderr)) def get_logger(config: LoggerConfig, @@ -252,9 +253,9 @@ def get_logger(config: LoggerConfig, return root -# ---------------------------- -# Oracles report helpers -# ---------------------------- +# ------------------------------------------------------------------------------ +# Oracle report helpers +# ------------------------------------------------------------------------------ class _RequirementLike(typing.Protocol): @@ -364,9 +365,9 @@ def record_result( return score -# ---------------------------- +# ------------------------------------------------------------------------------ # Misc helpers -# ---------------------------- +# ------------------------------------------------------------------------------ def decode_text(value: object | None) -> str: