Skip to content

Commit c33f5b2

Browse files
committed
Gate solver-dependent viz tests and assert catalog hard-fail
1 parent 6e70c7d commit c33f5b2

2 files changed

Lines changed: 145 additions & 13 deletions

File tree

tests/integration/l5_full_pipeline/test_full_graph.py

Lines changed: 122 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -251,6 +251,8 @@ def test_state_initialization(self, tmp_path):
251251

252252
@pytest.mark.integration_full
253253
@pytest.mark.slow
254+
@pytest.mark.requires_solver("PeleC")
255+
@pytest.mark.requires_repos("PeleC")
254256
class TestVisualizationParameterExtraction:
255257
"""
256258
Visualization parameter extraction end-to-end.
@@ -271,23 +273,40 @@ def _plan(
271273
) -> SimulationPlan:
272274
return SimulationPlan(
273275
selected_solver="PeleC",
274-
selected_case="PeleC/Exec/RegTests/PMF",
276+
selected_case="Exec/RegTests/PMF",
275277
modifications=[],
276278
reasoning="Visualization extraction baseline test",
277279
baseline_confidence=0.9,
278280
prompt=prompt,
279281
baseline={
280282
"code_name": "PeleC",
281-
"repo_path": str(baseline_dir.parents[3]),
283+
"repo_path": str(baseline_dir.parents[2]),
282284
"case_path": "Exec/RegTests/PMF",
283285
"local_path": str(baseline_dir),
284286
},
285287
visualization=visualization or {},
286288
)
287289

288290
def _make_baseline_dir(self, tmp_path: Path) -> Path:
291+
repo_root = tmp_path / "PeleC"
289292
baseline = tmp_path / "PeleC" / "Exec" / "RegTests" / "PMF"
290293
baseline.mkdir(parents=True, exist_ok=True)
294+
source_dir = repo_root / "Source"
295+
source_dir.mkdir(parents=True, exist_ok=True)
296+
(source_dir / "Setup.cpp").write_text(
297+
"\n".join(
298+
[
299+
"void setup() {",
300+
' name[cnt] = "density";',
301+
' name[cnt] = "Temp";',
302+
' derive_lst.add("magvel");',
303+
' derive_lst.add("magvort");',
304+
' derive_lst.add("z_velocity");',
305+
"}",
306+
"",
307+
]
308+
)
309+
)
291310
(baseline / "AMReX.ex").write_text("#!/bin/bash\nexit 0\n")
292311
(baseline / "inputs").write_text(
293312
"\n".join(
@@ -307,7 +326,7 @@ def _run_pipeline(self, prompt: str, tmp_path: Path, visualization: dict | None
307326
config = AMReXAgentConfig()
308327
config.output_dir = tmp_path / "output"
309328
config.environment = "perlmutter"
310-
config.repositories = {"PeleC": baseline_dir.parents[3]}
329+
config.repositories = {"PeleC": baseline_dir.parents[2]}
311330
config.run_mode = "dry_run"
312331
config.dry_run = True
313332

@@ -399,7 +418,7 @@ def test_temperature_in_prompt_produces_plotfile_var(
399418
visualization={"quantities": ["temperature"]},
400419
)
401420
plot_vars = self._plot_vars(self._read_inputs_text(final_state))
402-
assert "temperature" in plot_vars
421+
assert "Temp" in plot_vars
403422

404423
def test_multiple_quantities_all_in_plotfile_vars(
405424
self, tmp_path):
@@ -420,9 +439,9 @@ def test_multiple_quantities_all_in_plotfile_vars(
420439
},
421440
)
422441
plot_vars = self._plot_vars(self._read_inputs_text(final_state))
423-
assert "temperature" in plot_vars
424-
assert "velocity" in plot_vars
425-
assert "vorticity" in plot_vars
442+
assert "Temp" in plot_vars
443+
assert "magvel" in plot_vars
444+
assert "magvort" in plot_vars
426445

427446
def test_log_scale_in_prompt_sets_viz_metadata(
428447
self, tmp_path):
@@ -492,5 +511,99 @@ def test_squall_line_with_viz_params(self, tmp_path):
492511
},
493512
)
494513
plot_vars = self._plot_vars(self._read_inputs_text(final_state))
495-
assert "vertical_velocity" in plot_vars
496-
assert "temperature" in plot_vars
514+
assert "z_velocity" in plot_vars
515+
assert "Temp" in plot_vars
516+
517+
518+
@pytest.mark.integration_full
519+
class TestVisualizationCatalogHardFail:
520+
def test_missing_solver_catalog_hard_fails_before_inputs_write(self, tmp_path):
521+
baseline = tmp_path / "PeleC" / "Exec" / "RegTests" / "PMF"
522+
baseline.mkdir(parents=True, exist_ok=True)
523+
(baseline / "inputs").write_text("amr.plot_vars = density pressure\n")
524+
(baseline / "AMReX.ex").write_text("#!/bin/bash\nexit 0\n")
525+
526+
config = AMReXAgentConfig()
527+
config.output_dir = tmp_path / "output"
528+
config.environment = "perlmutter"
529+
config.repositories = {"PeleC": baseline.parents[2]}
530+
config.run_mode = "dry_run"
531+
config.dry_run = True
532+
533+
class DummyEmbeddingService:
534+
embeddings = None
535+
536+
class DummyRunner:
537+
def __init__(self, _config):
538+
pass
539+
540+
def setup_job(self, output_dir, case_dir, inputs_path=None):
541+
return {"run_dir": output_dir, "executable": "AMReX.ex"}
542+
543+
def submit(self, run_directory, nodes=None, run_mode=None, dry_run=None, case_dir=None):
544+
return {
545+
"job_id": "viz_catalog_fail_123",
546+
"method": "sbatch",
547+
"script_path": str(Path(run_directory) / "submit.sh"),
548+
"job_status": "completed",
549+
}
550+
551+
with patch("src.services.embedding_service_factory.get_embedding_service", return_value=DummyEmbeddingService()), \
552+
patch("src.nodes.architect_node.ArchitectService") as MockArch, \
553+
patch("src.nodes.reviewer_node.ReviewerOrchestrator") as MockRev, \
554+
patch("src.services.cases.AMReXCasesService", return_value=object()), \
555+
patch("src.nodes.runner_node.SuperfacilityRunner", DummyRunner), \
556+
patch("src.nodes.analysis_node.AnalysisService") as MockAnalysis:
557+
558+
MockArch.return_value.execute_planning.return_value = SimulationPlan(
559+
selected_solver="PeleC",
560+
selected_case="Exec/RegTests/PMF",
561+
modifications=[],
562+
reasoning="catalog hard-fail integration test",
563+
baseline_confidence=0.9,
564+
prompt="plot temperature",
565+
baseline={
566+
"code_name": "PeleC",
567+
"repo_path": str(baseline.parents[2]),
568+
"case_path": "Exec/RegTests/PMF",
569+
"local_path": str(baseline),
570+
},
571+
visualization={"quantities": ["temperature"]},
572+
)
573+
MockRev.return_value.validate_plan.return_value = ValidationResult(
574+
mode="proceed",
575+
violations=[],
576+
summary="ok",
577+
available_schema_params=[],
578+
)
579+
MockAnalysis.return_value.analyze_simulation.return_value = {
580+
"status": "success",
581+
"total_steps": 1,
582+
"final_time": 0.01,
583+
"issues": [],
584+
"warnings": [],
585+
"completed": True,
586+
}
587+
588+
final_state = run_agent(
589+
user_requirement="Run PMF and plot temperature.",
590+
config=config,
591+
)
592+
593+
assert final_state.get("inputs_file_path") is None
594+
errors_active = [str(err) for err in final_state.get("errors_active", [])]
595+
assert any(
596+
"Visualization mapping requires solver code-derived catalog" in err
597+
for err in errors_active
598+
)
599+
input_writer_entries = [
600+
entry
601+
for entry in final_state.get("workflow_history", [])
602+
if entry.get("node") == "input_writer"
603+
]
604+
assert input_writer_entries
605+
detail_errors = [str(err) for err in input_writer_entries[-1].get("details", {}).get("errors", [])]
606+
assert any(
607+
"Visualization mapping requires solver code-derived catalog" in err
608+
for err in detail_errors
609+
)

tests/unit/test_visualization_intent_model.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
visualization_intent_node,
66
)
77
from src.services.viz_param_extractor import VizMappingCatalogUnavailableError
8+
from unittest.mock import patch
89

910

1011
def test_visualization_intent_model_defaults():
@@ -19,10 +20,28 @@ def test_visualization_intent_model_defaults():
1920

2021

2122
def test_build_visualization_intent_from_prompt_cadence():
22-
model = build_visualization_intent(
23-
prompt="show cloud water every 2 minutes",
24-
solver_name="ERF",
25-
)
23+
class _MockConfig:
24+
@classmethod
25+
def get_viz_tier1_intents(cls):
26+
return {"cloud_water": {"aliases": ["cloud water", "cloud_water"]}}
27+
28+
@classmethod
29+
def build_viz_tier2_candidates(cls, repo_root=None):
30+
del cls, repo_root
31+
return {
32+
"cloud_water": [
33+
{"name": "qc", "aliases": ["cloud_water", "cloud water"]},
34+
]
35+
}
36+
37+
with patch(
38+
"database.configs.registry.get_config_class",
39+
lambda code_name: _MockConfig,
40+
):
41+
model = build_visualization_intent(
42+
prompt="show cloud water every 2 minutes",
43+
solver_name="ERF",
44+
)
2645
assert "qc" in model.requested_fields
2746
assert model.cadence_prompt_seconds == 120
2847
assert model.cadence_solver_time == 120.0

0 commit comments

Comments
 (0)