Skip to content

Commit 16c491d

Browse files
committed
fix: web adapter pre-evaluates dependencies via Python engine
The RegelrechtMachineService web adapter now pre-evaluates cross-law dependencies using the Python engine (matching the BDD wrapper approach) instead of passing raw YAML as extra_laws. This fixes "Variable not found" errors when evaluating laws with cross-law references (zorgtoeslag → wet_brp, wet_inkomstenbelasting). Also adds TypeSpec enforcement, voldoet_aan_voorwaarden handling, temporal reference resolution, and numpy type conversion. Tested: zorgtoeslag evaluates correctly in web UI showing €1.772,62.
1 parent 1158bed commit 16c491d

1 file changed

Lines changed: 172 additions & 56 deletions

File tree

web/engines/regelrecht_engine/engine.py

Lines changed: 172 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
from datetime import date, datetime
66
from typing import Any
77

8+
import numpy as np
89
import pandas as pd
910
import yaml
1011
from fastapi import HTTPException
1112

13+
from machine.context import TypeSpec
1214
from machine.profile_loader import get_project_root, load_profiles_from_yaml
1315
from machine.service import Services
1416
from machine.utils import RuleResolver
@@ -90,8 +92,7 @@ def evaluate(
9092
with open(rule.path) as f:
9193
yaml_content = f.read()
9294

93-
# Collect cross-law dependencies
94-
extra_laws = self._collect_extra_laws(yaml_content, reference_date)
95+
parsed_yaml = yaml.safe_load(yaml_content)
9596

9697
# Build params dict from parameters
9798
params = dict(parameters)
@@ -102,36 +103,44 @@ def evaluate(
102103
if isinstance(section_values, dict):
103104
params.update(section_values)
104105

106+
# Pre-evaluate cross-law dependencies using the Python engine
107+
dep_outputs = self._pre_evaluate_dependencies(
108+
parsed_yaml, reference_date, parameters, overwrite_input, approved
109+
)
110+
params.update(dep_outputs)
111+
105112
# Pre-resolve source_reference fields from services' DataFrames
106113
source_params = self._pre_resolve_sources(yaml_content, params)
107114
params.update(source_params)
108115

116+
# Strip source references so the CLI uses params for all input values
117+
stripped_yaml = self._strip_sources(yaml_content)
118+
109119
# Get output field names from the YAML spec
110-
parsed_yaml = yaml.safe_load(yaml_content)
111120
output_names = self._get_output_names(parsed_yaml)
112121

113122
if requested_output:
114123
output_names = [requested_output]
115124

116125
# Call CLI for each output field and merge results
117-
merged_outputs = {}
118-
merged_resolved_inputs = {}
126+
merged_outputs: dict[str, Any] = {}
127+
merged_resolved_inputs: dict[str, Any] = {}
119128
last_uuid = rule.uuid
120129
had_error = False
121130
missing_required = False
122131

123132
for output_name in output_names:
124133
cli_result = self._call_cli(
125-
yaml_content=yaml_content,
134+
yaml_content=stripped_yaml,
126135
output_name=output_name,
127-
params=params,
136+
params=_convert_to_native(params),
128137
reference_date=reference_date,
129-
extra_laws=extra_laws,
138+
extra_laws=[],
130139
)
131140

132141
if "error" in cli_result:
133142
error_msg = cli_result["error"]
134-
logger.warning(f"Regelrecht CLI error for {law}/{output_name}: {error_msg}")
143+
logger.warning("Regelrecht CLI error for %s/%s: %s", law, output_name, error_msg)
135144
if "missing" in error_msg.lower() or "variable" in error_msg.lower():
136145
missing_required = True
137146
had_error = True
@@ -149,12 +158,23 @@ def evaluate(
149158
if "law_uuid" in cli_result:
150159
last_uuid = cli_result["law_uuid"]
151160

152-
enriched_output = {}
161+
# Strip voldoet_aan_voorwaarden from outputs and use it for requirements_met
162+
voldoet = merged_outputs.pop("voldoet_aan_voorwaarden", None)
163+
if voldoet is not None:
164+
requirements_met = bool(voldoet) and not had_error
165+
else:
166+
requirements_met = bool(merged_outputs) and not had_error
167+
168+
# Enforce output type specs (precision, min/max, eurocent conversion)
169+
output_specs = _build_output_specs(parsed_yaml)
170+
for key in list(merged_outputs.keys()):
171+
if key in output_specs:
172+
merged_outputs[key] = output_specs[key].enforce(merged_outputs[key])
173+
174+
enriched_output: dict[str, Any] = {}
153175
for name, value in merged_outputs.items():
154176
enriched_output[name] = value
155177

156-
requirements_met = bool(merged_outputs) and not had_error
157-
158178
return RuleResult(
159179
output=enriched_output,
160180
requirements_met=requirements_met,
@@ -250,51 +270,90 @@ def _call_cli(
250270
logger.error(f"Failed to parse CLI output as JSON: {e}")
251271
return {"error": f"Invalid JSON from CLI: {e}"}
252272

253-
# ---- Cross-law dependency collection ----
273+
# ---- Dependency pre-evaluation via Python engine ----
254274

255-
def _collect_extra_laws(self, yaml_content: str, reference_date: str) -> list[str]:
256-
"""Scan YAML for source.regulation references and collect those law YAML files.
275+
def _pre_evaluate_dependencies(
276+
self,
277+
parsed_yaml: dict,
278+
reference_date: str,
279+
parameters: dict[str, Any],
280+
overwrite_input: dict[str, Any] | None,
281+
approved: bool,
282+
) -> dict[str, Any]:
283+
"""Pre-evaluate dependency laws using the Python engine.
257284
258-
Recursively resolves dependencies so that transitive references are also included.
285+
Walks the main law's input fields looking for source.regulation references,
286+
evaluates each referenced law via the Python Services instance, and maps the
287+
output to the input name expected by the main law.
259288
"""
260-
data = yaml.safe_load(yaml_content)
261-
regulations: set[str] = set()
262-
_find_regulations(data, regulations)
289+
if not self.services:
290+
return {}
263291

264-
extra_yamls: list[str] = []
265-
visited: set[str] = set()
292+
resolved: dict[str, Any] = {}
266293

267-
# Use a work queue for recursive resolution
268-
queue = list(regulations)
269-
while queue:
270-
reg = queue.pop()
271-
if reg in visited:
272-
continue
273-
visited.add(reg)
294+
for article in parsed_yaml.get("articles", []):
295+
mr = article.get("machine_readable", {})
296+
execution = mr.get("execution", {})
297+
for inp in execution.get("input", []):
298+
input_name = inp.get("name")
299+
source = inp.get("source", {})
300+
if not source or not input_name:
301+
continue
274302

275-
try:
276-
rule = self.resolver.find_rule(reg, reference_date)
277-
except ValueError:
278-
logger.debug(f"Referenced law not found: {reg}")
279-
continue
303+
regulation = source.get("regulation")
304+
output_name = source.get("output")
305+
dep_service = source.get("service")
306+
if not regulation or not output_name:
307+
continue
280308

281-
if not rule:
282-
continue
309+
if input_name in resolved:
310+
continue
283311

284-
with open(rule.path) as f:
285-
extra_yaml = f.read()
312+
# Resolve temporal reference to adjust the reference_date
313+
dep_reference_date = _resolve_temporal_reference(inp, reference_date)
286314

287-
extra_yamls.append(extra_yaml)
315+
try:
316+
dep_result = self.services.evaluate(
317+
service=dep_service or "UNKNOWN",
318+
law=regulation,
319+
parameters=parameters,
320+
reference_date=dep_reference_date,
321+
overwrite_input=overwrite_input,
322+
requested_output=output_name,
323+
approved=approved,
324+
)
325+
if dep_result and dep_result.output:
326+
value = dep_result.output.get(output_name)
327+
if value is not None:
328+
resolved[input_name] = value
329+
except Exception as e:
330+
logger.debug(
331+
"Failed to pre-evaluate dependency %s/%s: %s",
332+
regulation,
333+
output_name,
334+
e,
335+
)
288336

289-
# Scan this dependency for further references
290-
dep_data = yaml.safe_load(extra_yaml)
291-
dep_regulations: set[str] = set()
292-
_find_regulations(dep_data, dep_regulations)
293-
for dep_reg in dep_regulations:
294-
if dep_reg not in visited:
295-
queue.append(dep_reg)
337+
return resolved
296338

297-
return extra_yamls
339+
@staticmethod
340+
def _strip_sources(yaml_content: str) -> str:
341+
"""Strip source/source_reference from inputs so CLI uses params."""
342+
data = yaml.safe_load(yaml_content)
343+
changed = False
344+
for article in data.get("articles", []):
345+
mr = article.get("machine_readable", {})
346+
execution = mr.get("execution", {})
347+
for inp in execution.get("input", []):
348+
if "source" in inp:
349+
del inp["source"]
350+
changed = True
351+
if "source_reference" in inp:
352+
del inp["source_reference"]
353+
changed = True
354+
if changed:
355+
return yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False)
356+
return yaml_content
298357

299358
# ---- Source reference pre-resolution ----
300359

@@ -450,13 +509,70 @@ def _get_output_specs(data: dict) -> dict[str, dict[str, Any]]:
450509
return specs
451510

452511

453-
def _find_regulations(data: Any, regulations: set[str]) -> None:
454-
"""Recursively find all source.regulation references in a YAML structure."""
455-
if isinstance(data, dict):
456-
if "regulation" in data and isinstance(data["regulation"], str):
457-
regulations.add(data["regulation"])
458-
for value in data.values():
459-
_find_regulations(value, regulations)
460-
elif isinstance(data, list):
461-
for item in data:
462-
_find_regulations(item, regulations)
512+
def _resolve_temporal_reference(input_spec: dict, reference_date: str) -> str:
513+
"""Resolve temporal reference on an input spec to get the adjusted reference_date.
514+
515+
For example, if the input has temporal.reference = $prev_january_first and the
516+
reference_date is 2025-02-01, the resolved date is 2024-01-01.
517+
"""
518+
temporal = input_spec.get("temporal", {})
519+
ref = temporal.get("reference")
520+
if not ref or not isinstance(ref, str) or not ref.startswith("$"):
521+
return reference_date
522+
523+
ref_name = ref[1:]
524+
calc_date = datetime.strptime(reference_date, "%Y-%m-%d").date()
525+
526+
if ref_name == "prev_january_first":
527+
return calc_date.replace(month=1, day=1, year=calc_date.year - 1).isoformat()
528+
elif ref_name == "january_first":
529+
return calc_date.replace(month=1, day=1).isoformat()
530+
elif ref_name in ("calculation_date", "year"):
531+
return reference_date
532+
533+
return reference_date
534+
535+
536+
def _build_output_specs(data: dict) -> dict[str, TypeSpec]:
537+
"""Build mapping of output names to their TypeSpec from parsed YAML.
538+
539+
This mirrors the BDD wrapper so the web adapter can enforce the same
540+
precision/min/max/eurocent constraints.
541+
"""
542+
specs: dict[str, TypeSpec] = {}
543+
for article in data.get("articles", []):
544+
mr = article.get("machine_readable", {})
545+
execution = mr.get("execution", {})
546+
for out in execution.get("output", []):
547+
name = out.get("name")
548+
if name:
549+
ts = out.get("type_spec", {})
550+
specs[name] = TypeSpec(
551+
type=out.get("type"),
552+
unit=ts.get("unit"),
553+
precision=ts.get("precision"),
554+
min=ts.get("min"),
555+
max=ts.get("max"),
556+
)
557+
return specs
558+
559+
560+
def _convert_to_native(obj: Any) -> Any:
561+
"""Recursively convert numpy/pandas types to native Python types for JSON serialization."""
562+
if isinstance(obj, dict):
563+
return {k: _convert_to_native(v) for k, v in obj.items()}
564+
if isinstance(obj, list):
565+
return [_convert_to_native(v) for v in obj]
566+
if isinstance(obj, (np.integer,)):
567+
return int(obj)
568+
if isinstance(obj, (np.floating,)):
569+
return float(obj)
570+
if isinstance(obj, (np.bool_,)):
571+
return bool(obj)
572+
if isinstance(obj, np.ndarray):
573+
return obj.tolist()
574+
if isinstance(obj, pd.Timestamp):
575+
return obj.isoformat()
576+
if pd.isna(obj) if isinstance(obj, float) else False:
577+
return None
578+
return obj

0 commit comments

Comments
 (0)