|
5 | 5 | from typing import Any |
6 | 6 | from jsonschema import validate |
7 | 7 | from pydantic import BaseModel, create_model, Field |
| 8 | +from importlib import resources |
8 | 9 |
|
9 | 10 | from metriq_gym.constants import JobType, SCHEMA_MAPPING |
10 | 11 |
|
11 | | - |
| 12 | +SCHEMA_DIR_NAME = "schemas" |
12 | 13 | CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) |
13 | | -DEFAULT_SCHEMA_DIR = os.path.join(CURRENT_DIR, "schemas") |
| 14 | +DEFAULT_SCHEMA_DIR = os.path.join(CURRENT_DIR, SCHEMA_DIR_NAME) |
14 | 15 | BENCHMARK_NAME_KEY = "benchmark_name" |
15 | 16 |
|
16 | 17 |
|
| 18 | +def _schema_resource_path(filename: str) -> str | None: |
| 19 | + """Return a filesystem path to a bundled schema JSON file. |
| 20 | +
|
| 21 | + Tries importlib.resources (works for wheels / site-packages). Falls back to |
| 22 | + direct path inside the source tree if not found. Returns None if neither exists. |
| 23 | + """ |
| 24 | + # Attempt to access as a package resource |
| 25 | + try: |
| 26 | + candidate = resources.files("metriq_gym").joinpath(SCHEMA_DIR_NAME, filename) |
| 27 | + if candidate.is_file(): |
| 28 | + # Convert to string path so existing loaders keep working |
| 29 | + return str(candidate) |
| 30 | + except (ModuleNotFoundError, FileNotFoundError): |
| 31 | + pass |
| 32 | + |
| 33 | + # Fallback: direct path relative to source (development / editable install) |
| 34 | + direct_path = os.path.join(DEFAULT_SCHEMA_DIR, filename) |
| 35 | + if os.path.isfile(direct_path): |
| 36 | + return direct_path |
| 37 | + return None |
| 38 | + |
| 39 | + |
17 | 40 | def load_json_file(file_path: str) -> dict: |
18 | 41 | """Load and parse a JSON file.""" |
19 | 42 | with open(file_path, "r") as file: |
20 | 43 | return json.load(file) |
21 | 44 |
|
22 | 45 |
|
23 | 46 | def load_schema(benchmark_name: str, schema_dir: str = DEFAULT_SCHEMA_DIR) -> dict: |
24 | | - """Load a JSON schema based on the benchmark name.""" |
| 47 | + """Load a JSON schema based on the benchmark name. |
| 48 | +
|
| 49 | + Uses package resources for installed distributions; falls back to local path. |
| 50 | + """ |
25 | 51 | schema_filename = SCHEMA_MAPPING.get(JobType(benchmark_name)) |
26 | 52 | if not schema_filename: |
27 | 53 | raise ValueError(f"Unsupported benchmark: {benchmark_name}") |
28 | 54 |
|
29 | | - schema_path = os.path.join(schema_dir, schema_filename) |
30 | | - return load_json_file(schema_path) |
| 55 | + # Prefer packaged resource; allow overriding via explicit schema_dir argument |
| 56 | + if schema_dir != DEFAULT_SCHEMA_DIR: |
| 57 | + candidate = os.path.join(schema_dir, schema_filename) |
| 58 | + if not os.path.isfile(candidate): |
| 59 | + raise FileNotFoundError(f"Schema file not found: {candidate}") |
| 60 | + return load_json_file(candidate) |
| 61 | + |
| 62 | + resource_path = _schema_resource_path(schema_filename) |
| 63 | + if resource_path is None: |
| 64 | + raise FileNotFoundError( |
| 65 | + f"Schema file '{schema_filename}' not found in package resources or '{DEFAULT_SCHEMA_DIR}'." |
| 66 | + ) |
| 67 | + return load_json_file(resource_path) |
31 | 68 |
|
32 | 69 |
|
33 | 70 | def create_pydantic_model(schema: dict[str, Any]) -> Any: |
|
0 commit comments