|
| 1 | +from dataclasses import dataclass |
1 | 2 | import logging |
2 | 3 | import pytest |
3 | 4 | from unittest.mock import MagicMock, patch |
4 | 5 |
|
5 | 6 | from qbraid import QbraidError |
| 7 | +from qbraid.runtime import JobStatus |
| 8 | +from metriq_gym.benchmarks.benchmark import BenchmarkData, BenchmarkResult |
6 | 9 | from metriq_gym.run import ( |
7 | 10 | setup_device, |
8 | 11 | dispatch_job, |
| 12 | + fetch_result, |
9 | 13 | ) |
| 14 | +from metriq_gym.job_manager import MetriqGymJob, JobManager |
| 15 | +from metriq_gym.constants import JobType |
10 | 16 | from metriq_gym.exceptions import QBraidSetupError |
11 | 17 |
|
12 | 18 |
|
@@ -113,3 +119,100 @@ def test_dispatch_missing_config_file(mock_exists, mock_args, mock_job_manager, |
113 | 119 | # Verify output shows file not found errors |
114 | 120 | captured = capsys.readouterr() |
115 | 121 | assert "Configuration file not found" in captured.out |
| 122 | + |
| 123 | + |
| 124 | +class DummyResult(BenchmarkResult): |
| 125 | + value: int |
| 126 | + |
| 127 | + |
| 128 | +@dataclass |
| 129 | +class DummyJobData(BenchmarkData): |
| 130 | + provider_job_ids: list[str] |
| 131 | + |
| 132 | + |
| 133 | +class DummyQuantumJob: |
| 134 | + def __init__(self, job_id: str, value: int): |
| 135 | + self.id = job_id |
| 136 | + self._value = value |
| 137 | + |
| 138 | + def status(self): |
| 139 | + return JobStatus.COMPLETED |
| 140 | + |
| 141 | + def result(self): |
| 142 | + class _R: |
| 143 | + data = {"value": self._value} |
| 144 | + |
| 145 | + return _R() |
| 146 | + |
| 147 | + |
| 148 | +class DummyBenchmark: |
| 149 | + def poll_handler(self, job_data, result_data, quantum_jobs): |
| 150 | + return DummyResult(value=result_data[0]["value"]) |
| 151 | + |
| 152 | + |
| 153 | +def _make_cached_job(val: int) -> MetriqGymJob: |
| 154 | + return MetriqGymJob( |
| 155 | + id="job-1", |
| 156 | + job_type=JobType.WIT, |
| 157 | + params={"benchmark_name": JobType.WIT.name}, |
| 158 | + data={"provider_job_ids": ["provider-job-1"]}, |
| 159 | + provider_name="local", |
| 160 | + device_name="dummy_device", |
| 161 | + platform={}, |
| 162 | + dispatch_time=None, |
| 163 | + result_data={"value": val}, |
| 164 | + ) |
| 165 | + |
| 166 | + |
| 167 | +def test_fetch_result_uses_cache_when_no_flag(monkeypatch): |
| 168 | + EXPECTED_CACHED_VALUE = 7 |
| 169 | + job = _make_cached_job(EXPECTED_CACHED_VALUE) |
| 170 | + jm = JobManager() |
| 171 | + jm.jobs.append(job) |
| 172 | + args = MagicMock() |
| 173 | + args.no_cache = False |
| 174 | + |
| 175 | + import metriq_gym.run as run_mod |
| 176 | + |
| 177 | + monkeypatch.setattr(run_mod, "setup_benchmark_result_class", lambda *_: DummyResult) |
| 178 | + monkeypatch.setattr(run_mod, "setup_job_data_class", lambda *_: DummyJobData) |
| 179 | + monkeypatch.setattr(run_mod, "setup_benchmark", lambda *_, **__: DummyBenchmark()) |
| 180 | + monkeypatch.setattr( |
| 181 | + run_mod, |
| 182 | + "load_job", |
| 183 | + lambda *_, **__: DummyQuantumJob("provider-job-1", EXPECTED_CACHED_VALUE), |
| 184 | + ) |
| 185 | + monkeypatch.setattr(run_mod, "validate_and_create_model", lambda params: params) |
| 186 | + |
| 187 | + result = fetch_result(job, args, jm) |
| 188 | + assert result.value == EXPECTED_CACHED_VALUE |
| 189 | + |
| 190 | + |
| 191 | +def test_fetch_result_bypasses_cache_with_flag(monkeypatch): |
| 192 | + EXPECTED_FRESH_VALUE = 42 |
| 193 | + CACHED_VALUE = 7 |
| 194 | + job = _make_cached_job(CACHED_VALUE) |
| 195 | + jm = JobManager() |
| 196 | + jm.jobs.append(job) |
| 197 | + args = MagicMock() |
| 198 | + args.no_cache = True |
| 199 | + |
| 200 | + import metriq_gym.run as run_mod |
| 201 | + |
| 202 | + monkeypatch.setattr(run_mod, "setup_benchmark_result_class", lambda *_: DummyResult) |
| 203 | + monkeypatch.setattr(run_mod, "setup_job_data_class", lambda *_: DummyJobData) |
| 204 | + monkeypatch.setattr(run_mod, "setup_benchmark", lambda *_, **__: DummyBenchmark()) |
| 205 | + monkeypatch.setattr( |
| 206 | + run_mod, |
| 207 | + "load_job", |
| 208 | + lambda *_, **__: DummyQuantumJob("provider-job-1", EXPECTED_FRESH_VALUE), |
| 209 | + ) |
| 210 | + monkeypatch.setattr(run_mod, "validate_and_create_model", lambda params: params) |
| 211 | + |
| 212 | + result = fetch_result(job, args, jm) |
| 213 | + assert result.value == EXPECTED_FRESH_VALUE, ( |
| 214 | + "Should fetch fresh value when --no-cache specified" |
| 215 | + ) |
| 216 | + assert job.result_data == {"value": EXPECTED_FRESH_VALUE}, ( |
| 217 | + "Cached result_data should be updated" |
| 218 | + ) |
0 commit comments