Skip to content

Commit bed9774

Browse files
authored
Fix loading schema files (#618)
1 parent 1c36ffe commit bed9774

File tree

3 files changed

+48
-5
lines changed

3 files changed

+48
-5
lines changed

metriq_gym/schema_validator.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,66 @@
55
from typing import Any
66
from jsonschema import validate
77
from pydantic import BaseModel, create_model, Field
8+
from importlib import resources
89

910
from metriq_gym.constants import JobType, SCHEMA_MAPPING
1011

11-
12+
SCHEMA_DIR_NAME = "schemas"
1213
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
13-
DEFAULT_SCHEMA_DIR = os.path.join(CURRENT_DIR, "schemas")
14+
DEFAULT_SCHEMA_DIR = os.path.join(CURRENT_DIR, SCHEMA_DIR_NAME)
1415
BENCHMARK_NAME_KEY = "benchmark_name"
1516

1617

18+
def _schema_resource_path(filename: str) -> str | None:
19+
"""Return a filesystem path to a bundled schema JSON file.
20+
21+
Tries importlib.resources (works for wheels / site-packages). Falls back to
22+
direct path inside the source tree if not found. Returns None if neither exists.
23+
"""
24+
# Attempt to access as a package resource
25+
try:
26+
candidate = resources.files("metriq_gym").joinpath(SCHEMA_DIR_NAME, filename)
27+
if candidate.is_file():
28+
# Convert to string path so existing loaders keep working
29+
return str(candidate)
30+
except (ModuleNotFoundError, FileNotFoundError):
31+
pass
32+
33+
# Fallback: direct path relative to source (development / editable install)
34+
direct_path = os.path.join(DEFAULT_SCHEMA_DIR, filename)
35+
if os.path.isfile(direct_path):
36+
return direct_path
37+
return None
38+
39+
1740
def load_json_file(file_path: str) -> dict:
1841
"""Load and parse a JSON file."""
1942
with open(file_path, "r") as file:
2043
return json.load(file)
2144

2245

2346
def load_schema(benchmark_name: str, schema_dir: str = DEFAULT_SCHEMA_DIR) -> dict:
24-
"""Load a JSON schema based on the benchmark name."""
47+
"""Load a JSON schema based on the benchmark name.
48+
49+
Uses package resources for installed distributions; falls back to local path.
50+
"""
2551
schema_filename = SCHEMA_MAPPING.get(JobType(benchmark_name))
2652
if not schema_filename:
2753
raise ValueError(f"Unsupported benchmark: {benchmark_name}")
2854

29-
schema_path = os.path.join(schema_dir, schema_filename)
30-
return load_json_file(schema_path)
55+
# Prefer packaged resource; allow overriding via explicit schema_dir argument
56+
if schema_dir != DEFAULT_SCHEMA_DIR:
57+
candidate = os.path.join(schema_dir, schema_filename)
58+
if not os.path.isfile(candidate):
59+
raise FileNotFoundError(f"Schema file not found: {candidate}")
60+
return load_json_file(candidate)
61+
62+
resource_path = _schema_resource_path(schema_filename)
63+
if resource_path is None:
64+
raise FileNotFoundError(
65+
f"Schema file '{schema_filename}' not found in package resources or '{DEFAULT_SCHEMA_DIR}'."
66+
)
67+
return load_json_file(resource_path)
3168

3269

3370
def create_pydantic_model(schema: dict[str, Any]) -> Any:

metriq_gym/schemas/__init__.py

Whitespace-only changes.

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,12 @@ include = [
112112
"quantum_fourier_transform*"
113113
]
114114

115+
[tool.setuptools.package-data]
116+
metriq_gym = ["schemas/*.json"]
117+
118+
[tool.setuptools]
119+
include-package-data = true
120+
115121
[tool.setuptools_scm]
116122
write_to = "metriq_gym/_version.py"
117123
write_to_template = "__version__ = \"{version}\"\n"

0 commit comments

Comments
 (0)