Skip to content

Commit 6e70c7d

Browse files
committed
Implement code-gated viz mapping with clarification disambiguation
1 parent f40a532 commit 6e70c7d

20 files changed

Lines changed: 779 additions & 289 deletions

database/configs/base_amrex_config.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2384,6 +2384,68 @@ def get_viz_variable_catalog(cls, repo_root: Path | None = None) -> list[dict[st
23842384
return []
23852385

23862386
@classmethod
2387+
def get_viz_tier1_intents(cls) -> dict[str, dict[str, Any]]:
2388+
"""
2389+
Return canonical semantic visualization intents (Tier 1).
2390+
"""
2391+
return {
2392+
"temperature": {"aliases": ["temperature", "temp", "thermal"]},
2393+
"velocity": {"aliases": ["velocity", "speed"]},
2394+
"vertical_velocity": {
2395+
"aliases": [
2396+
"vertical velocity",
2397+
"vertical_velocity",
2398+
"w-velocity",
2399+
"w velocity",
2400+
"updraft",
2401+
"downdraft",
2402+
]
2403+
},
2404+
"pressure": {"aliases": ["pressure", "pres"]},
2405+
"density": {"aliases": ["density", "rho"]},
2406+
"vorticity": {"aliases": ["vorticity", "vort"]},
2407+
"cloud_water": {
2408+
"aliases": [
2409+
"cloud water",
2410+
"cloud_water",
2411+
"liquid water",
2412+
"cloud liquid",
2413+
"qc",
2414+
]
2415+
},
2416+
}
2417+
2418+
@classmethod
2419+
def build_viz_tier2_candidates(cls, repo_root: Path | None = None) -> dict[str, list[dict[str, Any]]]:
2420+
"""
2421+
Build solver candidates (Tier 2) from live source catalog.
2422+
"""
2423+
catalog = cls.get_viz_variable_catalog(repo_root=repo_root)
2424+
if not catalog:
2425+
return {}
2426+
2427+
intents = cls.get_viz_tier1_intents()
2428+
candidates: dict[str, list[dict[str, Any]]] = {}
2429+
for token, spec in intents.items():
2430+
aliases = {token.lower()}
2431+
for alias in spec.get("aliases", []) or []:
2432+
aliases.add(str(alias).strip().lower())
2433+
token_candidates: list[dict[str, Any]] = []
2434+
for entry in catalog:
2435+
if not isinstance(entry, dict):
2436+
continue
2437+
name = str(entry.get("name", "")).strip()
2438+
if not name:
2439+
continue
2440+
entry_aliases = {name.lower()}
2441+
for alias in entry.get("aliases", []) or []:
2442+
entry_aliases.add(str(alias).strip().lower())
2443+
if aliases.intersection(entry_aliases):
2444+
token_candidates.append(dict(entry))
2445+
if token_candidates:
2446+
candidates[token] = token_candidates
2447+
return candidates
2448+
23872449
def get_default_slice_axis(cls) -> str | None:
23882450
"""
23892451
Return preferred default slice-normal axis for visualization.
@@ -2399,6 +2461,13 @@ def get_plotfile_var_param(cls) -> str:
23992461
"""
24002462
return "amr.plot_vars"
24012463

2464+
@classmethod
2465+
def get_plot_var_param_candidates(cls) -> list[str]:
2466+
"""
2467+
Return ordered plot-var ParmParse candidate keys (primary first).
2468+
"""
2469+
return [cls.get_plotfile_var_param()]
2470+
24022471
@classmethod
24032472
def get_plotfile_period_param(cls) -> str | None:
24042473
"""

database/configs/erf_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def get_plotfile_period_param(cls) -> str | None:
4141
def get_plotfile_step_interval_param(cls) -> str | None:
4242
return "erf.plot_int_1"
4343

44+
@classmethod
45+
def get_plot_var_param_candidates(cls) -> list[str]:
46+
return ["erf.plot_vars_1", "amr.plot_vars", "plot_vars"]
47+
4448
# === Registry Metadata ===
4549
github_org = "erf-model"
4650
github_repo = "ERF"

database/configs/pelelmex_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,10 @@ class PeleLMeXConfig(BaseAMReXConfig):
171171
def get_plotfile_var_param(cls) -> str:
172172
return "peleLM.derive_plot_vars"
173173

174+
@classmethod
175+
def get_plot_var_param_candidates(cls) -> list[str]:
176+
return ["peleLM.derive_plot_vars", "amr.plot_vars", "plot_vars"]
177+
174178
@classmethod
175179
def get_plotfile_period_param(cls) -> str | None:
176180
return "amr.plot_per"

docs/intent_runtime_routing.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,26 @@ mixing it into inputs-file modifications.
1818
Runtime intent is consumed by runner and visualization paths. It is not passed
1919
into `ConfigModelFactory.apply_modifications`.
2020

21+
## Visualization Mapping Contract
22+
23+
Visualization field resolution is code-gated and tiered:
24+
25+
1. Tier 1 semantic intents: canonical tokens extracted from prompt text (for
26+
example `cloud_water`, `temperature`, `velocity`).
27+
2. Tier 2 solver candidates: live solver-code catalog candidates resolved from
28+
`build_viz_tier2_candidates(repo_root)`.
29+
3. Tier 3 executable selection: final `requested_fields` consumed by input
30+
writing/visualization, selected only from Tier 2 candidates.
31+
32+
Policy:
33+
34+
- Missing solver code/catalog during visualization mapping is a hard runtime
35+
failure with remediation guidance.
36+
- No blind fallback to unresolved generic semantic tokens in final
37+
`requested_fields`.
38+
- Ambiguity or no-match is routed to clarification with candidate-constrained
39+
question payloads.
40+
2141
## Precedence
2242

2343
Effective runtime values are resolved in this order:

docs/workflow_visualization.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,27 @@ The workflow generates visualizations from AMReX plotfiles after a successful
1212
analysis pass. Visualization uses the multi-backend service (AMReX tools first,
1313
then optional pyamrex, then yt) and records metadata in the graph state.
1414

15+
### Intent Resolution and Clarification
16+
17+
Visualization intent resolution is deterministic-first and code-derived:
18+
19+
- Architect emits semantic visualization intent only.
20+
- `visualization_intent_node` performs canonical mapping through solver
21+
Tier-2 candidates from live code catalogs.
22+
- If mapping is ambiguous/unresolved, the node records diagnostics:
23+
- `visualization_mapping_candidates`
24+
- `visualization_mapping_unresolved`
25+
- `visualization_mapping_source`
26+
- `visualization_mapping_confidence`
27+
- Clarification node asks constrained disambiguation questions using provided
28+
candidate sets, and AI responses outside the candidate set are rejected.
29+
30+
Hard-fail behavior:
31+
32+
- If solver catalog/source code is unavailable for a visualization mapping
33+
request, runtime fails with actionable remediation instead of silently
34+
guessing fields.
35+
1536
### Workflow Flow
1637

1738
```

src/models/graph_state_canonical.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ class GraphState(TypedDict, total=False):
5555
requested_plot_vars: Optional[List[str]] # Prompt-extracted plot quantities
5656
visualization_config: Optional[Dict[str, Any]] # Prompt-extracted visualization settings
5757
visualization_intent: Optional[Dict[str, Any]] # Canonical visualization intent contract
58+
visualization_mapping_candidates: Optional[Dict[str, Any]]
59+
visualization_mapping_unresolved: Optional[List[str]]
60+
visualization_mapping_source: Optional[str]
61+
visualization_mapping_confidence: Optional[float]
5862

5963
# ========================================
6064
# PLANNING PHASE (Architect outputs)

src/nodes/clarification_handler_node.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ def clarification_handler_node(state: dict[str, Any]) -> dict[str, Any]:
4848
"clarification_turns": turns + 1,
4949
"skip_further_clarification": False,
5050
}
51+
if answered_by == "ai_agent" and not _is_allowed_candidate_response(
52+
pending_record.get("question", {}),
53+
response,
54+
):
55+
return {
56+
"clarification_history": history,
57+
"clarification_needed": True,
58+
"clarification_turns": turns + 1,
59+
"skip_further_clarification": False,
60+
}
5161

5262
updated_record = _resolve_record(pending_record, response, answered_by)
5363
history[pending_index] = updated_record
@@ -177,6 +187,18 @@ def _merge_plot_vars(current: list[Any], resolved_value: str) -> list[str]:
177187
return merged
178188

179189

190+
def _is_allowed_candidate_response(question: dict[str, Any], response: str) -> bool:
191+
question_data = _question_to_dict(question)
192+
context = question_data.get("context", {})
193+
if not isinstance(context, dict):
194+
return True
195+
allowed = context.get("candidates")
196+
if not isinstance(allowed, list) or not allowed:
197+
return True
198+
allowed_set = {str(v).strip() for v in allowed if str(v).strip()}
199+
return str(response).strip() in allowed_set
200+
201+
180202
def _as_history(value: Any) -> list[dict[str, Any]]:
181203
if not isinstance(value, list):
182204
return []

src/nodes/clarification_node.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66

77
from src.models import GraphState
88
from src.models.clarification_schemas import ClarificationQuestion, ClarificationRecord
9+
from src.services.viz_param_extractor import get_plot_var_param_candidates
910

1011

1112
REQUIRED_FIELDS = ("n_cell", "max_level", "stop_time", "max_step")
1213
RESOURCE_FIELDS = ("node_count", "time_limit", "cluster")
13-
PLOTFILE_VAR_KEYS = ("amr.plot_vars", "peleLM.derive_plot_vars", "plot_vars")
1414
MAX_CLARIFICATION_TURNS = 3
1515

1616

@@ -261,8 +261,41 @@ def _check_level_4_visualization(
261261
prompt: str,
262262
requested_plot_vars: Any,
263263
) -> ClarificationQuestion | None:
264-
del state, locked_fields, prompt
265-
if requested_plot_vars == [] and not _has_plotfile_var(resolved):
264+
del locked_fields, prompt
265+
mapping_candidates = state.get("visualization_mapping_candidates")
266+
mapping_unresolved = state.get("visualization_mapping_unresolved")
267+
if isinstance(mapping_candidates, dict) and mapping_candidates:
268+
token = next(iter(mapping_candidates.keys()))
269+
entries = mapping_candidates.get(token, [])
270+
names: list[str] = []
271+
if isinstance(entries, list):
272+
for entry in entries:
273+
if isinstance(entry, dict):
274+
name = str(entry.get("name", "")).strip()
275+
if name:
276+
names.append(name)
277+
if names:
278+
return ClarificationQuestion(
279+
field_name="requested_plot_vars",
280+
question_text=f"Which field should represent '{token}' in visualization output?",
281+
decision_level=4,
282+
fallback_tier="amrex_generic",
283+
context={
284+
"reason": "visualization_mapping_ambiguity",
285+
"token": token,
286+
"candidates": names,
287+
},
288+
)
289+
if isinstance(mapping_unresolved, list) and mapping_unresolved:
290+
token = str(mapping_unresolved[0]).strip()
291+
return ClarificationQuestion(
292+
field_name="requested_plot_vars",
293+
question_text=f"No solver field matched '{token}'. Which plot variable should be used?",
294+
decision_level=4,
295+
fallback_tier="amrex_generic",
296+
context={"reason": "visualization_mapping_unresolved", "token": token, "candidates": []},
297+
)
298+
if requested_plot_vars == [] and not _has_plotfile_var(resolved, state):
266299
return ClarificationQuestion(
267300
field_name="plot_vars",
268301
question_text="Which variables should be written to plotfiles for visualization?",
@@ -386,7 +419,7 @@ def _base_context(
386419
"missing_fields": missing_fields,
387420
"ambiguous_fields": [],
388421
"requested_plot_vars_empty": requested_plot_vars == [],
389-
"missing_plotfile_vars": requested_plot_vars == [] and not _has_plotfile_var(resolved),
422+
"missing_plotfile_vars": requested_plot_vars == [] and not _has_plotfile_var(resolved, None),
390423
}
391424

392425

@@ -406,8 +439,12 @@ def _fallback_tier_for_field(field: str) -> str:
406439
return "free_text"
407440

408441

409-
def _has_plotfile_var(resolved: dict[str, Any]) -> bool:
410-
for key in PLOTFILE_VAR_KEYS:
442+
def _has_plotfile_var(resolved: dict[str, Any], state: GraphState | None) -> bool:
443+
solver = ""
444+
if isinstance(state, dict):
445+
solver = str(state.get("selected_solver") or "").strip()
446+
keys = get_plot_var_param_candidates(solver) + ["plot_vars"]
447+
for key in keys:
411448
if _has_value(resolved.get(key)):
412449
return True
413450
return False

src/nodes/visualization_intent_node.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from src.models import GraphState
1313
from src.models.visualization_intent import VisualizationIntent, VisualizationPlotSpec
1414
from src.services.viz_param_extractor import (
15-
canonicalize_requested_plot_vars,
1615
convert_prompt_seconds_to_solver_time,
1716
extract_viz_params_from_prompt,
17+
resolve_viz_field_mapping,
1818
)
1919

2020

@@ -84,6 +84,7 @@ def build_visualization_intent(
8484
requested_plot_vars: list[str] | None = None,
8585
visualization_config: dict[str, Any] | None = None,
8686
prior_intent: dict[str, Any] | None = None,
87+
mapping_diagnostics_out: dict[str, Any] | None = None,
8788
) -> VisualizationIntent:
8889
"""
8990
Build a canonical visualization_intent payload.
@@ -112,12 +113,21 @@ def build_visualization_intent(
112113
+ list(prior_fields)
113114
+ list(prior_vars)
114115
)
115-
if solver_name:
116-
merged_requested = canonicalize_requested_plot_vars(
116+
mapping_diagnostics: dict[str, Any] = {
117+
"resolved_fields": list(merged_requested),
118+
"candidate_fields_by_token": {},
119+
"unresolved_tokens": [],
120+
"ambiguous_tokens": [],
121+
"mapping_source": "semantic_only" if merged_requested else "none",
122+
"mapping_confidence": 1.0,
123+
}
124+
if solver_name and merged_requested:
125+
mapping_diagnostics = resolve_viz_field_mapping(
117126
merged_requested,
118127
code_name=solver_name,
119128
repo_root=repo_root,
120129
)
130+
merged_requested = list(mapping_diagnostics.get("resolved_fields", []))
121131

122132
merged_config: dict[str, Any] = {}
123133
if isinstance(extracted_config, dict):
@@ -184,6 +194,10 @@ def build_visualization_intent(
184194
if source not in {"prompt", "clarification", "default"}:
185195
source = "default"
186196

197+
if isinstance(mapping_diagnostics_out, dict):
198+
mapping_diagnostics_out.clear()
199+
mapping_diagnostics_out.update(mapping_diagnostics)
200+
187201
return VisualizationIntent(
188202
requested_fields=merged_requested,
189203
cadence_prompt_seconds=cadence_prompt_seconds,
@@ -248,18 +262,24 @@ def visualization_intent_node(state: GraphState) -> dict[str, Any]:
248262
prompt = str(state.get("prompt") or state.get("user_requirement") or "")
249263
prior_intent = state.get("visualization_intent") if isinstance(state.get("visualization_intent"), dict) else {}
250264

265+
mapping: dict[str, Any] = {}
251266
model = build_visualization_intent(
252267
prompt=prompt,
253268
solver_name=solver_name,
254269
repo_root=repo_root,
255270
requested_plot_vars=[],
256271
visualization_config={},
257272
prior_intent=prior_intent,
273+
mapping_diagnostics_out=mapping,
258274
)
259275
intent = model.model_dump()
260276

261277
return {
262278
"visualization_intent": intent,
263279
"requested_plot_vars": list(model.requested_fields),
264280
"visualization_config": dict(intent.get("visualization_config", {})),
281+
"visualization_mapping_candidates": dict(mapping.get("candidate_fields_by_token", {})),
282+
"visualization_mapping_unresolved": list(mapping.get("unresolved_tokens", [])),
283+
"visualization_mapping_source": mapping.get("mapping_source"),
284+
"visualization_mapping_confidence": mapping.get("mapping_confidence"),
265285
}

0 commit comments

Comments
 (0)