22
33from __future__ import annotations
44
5- from dataclasses import dataclass
6- from typing import List , Optional
5+ from dataclasses import dataclass , field
6+ import tempfile
7+ from pathlib import Path
8+ from typing import Dict , List , Optional
79
10+ from tla_eval .evaluation .semantics .manual_invariant_evaluator import ManualInvariantEvaluator
811from tla_eval .evaluation .syntax .compilation_check import CompilationCheckEvaluator
912from tla_eval .evaluation .semantics .runtime_coverage_evaluator import RuntimeCoverageEvaluator
13+ from tla_eval .evaluation .consistency .trace_validation import TraceValidationEvaluator
1014
1115
@dataclass
class EvaluationOutcome:
    """Bundle of per-phase results produced by one evaluation run."""

    # Raw result objects returned by each phase's evaluator; the optional
    # ones stay ``None`` when the corresponding phase did not run.
    compilation: object
    runtime: Optional[object]
    invariant: Optional[object]
    trace: Optional[object]

    # Aggregate verdict and the human-readable failure messages collected
    # while evaluating.
    success: bool
    errors: List[str]

    # Per-phase score and the weight applied to each phase, keyed by phase
    # name. Each instance gets its own fresh dict.
    phase_scores: Dict[str, float] = field(default_factory=dict)
    normalized_weights: Dict[str, float] = field(default_factory=dict)
2129
class SysMoEvaluator:
    """Wraps SysMoBench evaluators to provide a unified interface.

    An evaluation runs up to four phases -- syntax (compilation), runtime
    coverage, invariant verification and trace validation -- and records a
    score in ``[0.0, 1.0]`` for each. ``compute_score`` combines those
    scores using per-phase weights normalized to sum to one.
    """

    # Tasks for which trace validation runs when the caller does not
    # supply an explicit list.
    TRACE_ENABLED_TASKS_DEFAULT = {"spin"}

    # Phase names, in the order their scores are reported.
    _PHASES = ("syntax", "runtime", "trace", "invariant")

    class _TraceFailure:
        """Stand-in result returned when trace validation raises."""

        overall_success = False

        def __init__(self, message: str) -> None:
            self.trace_validation_error = message

    def __init__(
        self,
        compilation_timeout: int = 60,
        runtime_simulations: int = 100,
        runtime_depth: int = 100,
        runtime_timeout: int = 300,
        invariant_timeout: int = 300,
        phase_weights: Optional[Dict[str, float]] = None,
        trace_enabled_tasks: Optional[List[str]] = None,
    ) -> None:
        """Create the underlying evaluators and resolve phase weights.

        Args:
            compilation_timeout: Passed to the compilation evaluator as
                ``validation_timeout``.
            runtime_simulations: Passed to the runtime evaluator as
                ``num_simulations``.
            runtime_depth: Passed to the runtime evaluator as
                ``simulation_depth``.
            runtime_timeout: Passed to the runtime evaluator as
                ``tlc_timeout``.
            invariant_timeout: Passed to the invariant evaluator as
                ``tlc_timeout``.
            phase_weights: Optional overrides merged over the default of
                0.25 per phase.
            trace_enabled_tasks: Task names for which trace validation
                runs. ``None`` selects ``TRACE_ENABLED_TASKS_DEFAULT``; an
                empty list disables trace validation for every task.
        """
        self.compilation_eval = CompilationCheckEvaluator(validation_timeout=compilation_timeout)
        self.runtime_eval = RuntimeCoverageEvaluator(
            num_simulations=runtime_simulations,
            simulation_depth=runtime_depth,
            tlc_timeout=runtime_timeout,
        )

        self.phase_weights = {phase: 0.25 for phase in self._PHASES}
        if phase_weights:
            self.phase_weights.update(phase_weights)

        # Compare against None so that an explicit empty list really
        # disables trace validation instead of falling back to the default.
        if trace_enabled_tasks is None:
            trace_enabled_tasks = self.TRACE_ENABLED_TASKS_DEFAULT
        self.trace_enabled_tasks = set(trace_enabled_tasks)

        # Only initialize the invariant evaluator when its weight is
        # non-zero, to avoid unnecessary overhead.
        self.invariant_eval = None
        if self.phase_weights.get("invariant", 0) > 0:
            self.invariant_eval = ManualInvariantEvaluator(tlc_timeout=invariant_timeout)

    def evaluate(self, generation_result, task, method_name: str, model_name: str) -> EvaluationOutcome:
        """Run every applicable evaluation phase for a generated spec.

        Compilation always runs. Runtime coverage runs only when
        compilation succeeds. Invariant verification and trace validation
        run only when both the syntax and runtime scores are exactly 1.0
        (and, respectively, an invariant evaluator exists / trace
        validation is enabled for the task with a positive weight).

        Args:
            generation_result: Generated spec, forwarded to each evaluator.
            task: Object providing ``task_name`` and ``spec_module``.
            method_name: Forwarded to each evaluator.
            model_name: Forwarded to each evaluator.

        Returns:
            An ``EvaluationOutcome`` whose ``success`` is true only when
            every phase that was executed passed.
        """
        errors: List[str] = []
        runtime_result = None
        invariant_result = None
        trace_result = None
        phase_scores = {phase: 0.0 for phase in self._PHASES}
        trace_enabled = self._is_trace_enabled(task.task_name)
        normalized_weights = self._build_normalized_weights(trace_enabled)

        # NOTE(review): the task_name/method_name arguments are assumed to
        # mirror the runtime and invariant evaluator calls below -- confirm.
        comp_result = self.compilation_eval.evaluate(
            generation_result,
            task.task_name,
            method_name,
            model_name,
            task.spec_module,
        )
        phase_scores["syntax"] = self._compute_syntax_score(comp_result)

        # Flipped to False by any executed phase that fails or crashes.
        success = bool(comp_result.overall_success)

        if comp_result.overall_success:
            # Runtime coverage (phase 2)
            runtime_result = self.runtime_eval.evaluate(
                generation_result,
                task.task_name,
                method_name,
                model_name,
                task.spec_module,
            )
            phase_scores["runtime"] = self._compute_runtime_score(runtime_result)
            if not runtime_result.overall_success:
                success = False
                errors.append(f"Runtime: {getattr(runtime_result, 'error_message', 'failed')}")

            earlier_phases_perfect = phase_scores["syntax"] == 1.0 and phase_scores["runtime"] == 1.0

            # Invariant verification (phase 4) - optional, controlled by weight
            if self.invariant_eval and earlier_phases_perfect:
                try:
                    invariant_result = self.invariant_eval.evaluate(
                        generation_result,
                        task.task_name,
                        method_name,
                        model_name,
                        task.spec_module,
                    )
                    phase_scores["invariant"] = self._compute_invariant_score(invariant_result)
                    if not invariant_result.overall_success:
                        success = False
                        errors.append(
                            f"Invariant: {getattr(invariant_result, 'model_checking_error', 'failed')}"
                        )
                except Exception as exc:  # pylint: disable=broad-exception-caught
                    # A crashed phase was still executed, so it counts as a failure.
                    success = False
                    errors.append(f"Invariant: {exc}")
                    invariant_result = None
                    phase_scores["invariant"] = 0.0

            # Trace validation - only run when enabled and earlier phases are perfect
            if trace_enabled and normalized_weights.get("trace", 0) > 0 and earlier_phases_perfect:
                trace_result = self._run_trace_validation(
                    generation_result,
                    task,
                    method_name,
                    model_name,
                )
                phase_scores["trace"] = self._compute_trace_score(trace_result)
                if not getattr(trace_result, "overall_success", False):
                    success = False
                    error_msg = getattr(trace_result, "trace_validation_error", None) or "Trace validation failed"
                    errors.append(f"Trace: {error_msg}")
        else:
            errors.extend(
                f"Compilation: {err}" for err in comp_result.syntax_errors + comp_result.semantic_errors
            )

        return EvaluationOutcome(
            compilation=comp_result,
            runtime=runtime_result,
            invariant=invariant_result,
            trace=trace_result,
            success=success,
            errors=errors,
            phase_scores=phase_scores,
            normalized_weights=normalized_weights,
        )

    def compute_score(self, outcome: Optional[EvaluationOutcome]) -> float:
        """Compute the weighted final score across phases.

        Uses the weights stored on ``outcome``; when those are empty, falls
        back to this evaluator's weights with trace validation included.
        Phases without a recorded score contribute 0.0. Returns 0.0 when
        ``outcome`` is ``None``.
        """
        if outcome is None:
            return 0.0

        weights = outcome.normalized_weights or self._build_normalized_weights(True)
        return sum(weight * outcome.phase_scores.get(phase, 0.0) for phase, weight in weights.items())

    @staticmethod
    def _compute_syntax_score(comp_result) -> float:
        """Compute the syntax score, using the action success rate when available.

        Returns 1.0 on overall success; otherwise ``action_success_rate``
        clamped to ``[0.0, 1.0]``, or 0.0 when that attribute is missing
        or ``None``.
        """
        if comp_result is None:
            return 0.0
        if getattr(comp_result, "overall_success", False):
            return 1.0
        action_rate = getattr(comp_result, "action_success_rate", None) or 0.0
        return max(0.0, min(1.0, action_rate))

    @staticmethod
    def _compute_runtime_score(runtime_result) -> float:
        """Compute the runtime score from the coverage metric when available.

        Returns 0.0 unless the result reports overall success; otherwise
        ``custom_data["runtime_coverage_score"]`` clamped to ``[0.0, 1.0]``
        (0.0 when ``custom_data`` or the key is absent).
        """
        if runtime_result is None or not getattr(runtime_result, "overall_success", False):
            return 0.0
        custom = getattr(runtime_result, "custom_data", None) or {}
        coverage = custom.get("runtime_coverage_score", 0.0) or 0.0
        return max(0.0, min(1.0, coverage))

    @staticmethod
    def _compute_invariant_score(invariant_result) -> float:
        """Compute the invariant score based on the pass ratio.

        When ``custom_data`` reports a positive ``total_invariants``, the
        score is ``passed_invariants / total_invariants`` clamped to
        ``[0.0, 1.0]``. Otherwise the score is 1.0 or 0.0 according to
        ``overall_success``.
        """
        if invariant_result is None:
            return 0.0
        custom = getattr(invariant_result, "custom_data", None) or {}
        passed = custom.get("passed_invariants", 0) or 0
        total = custom.get("total_invariants", 0) or 0
        if total > 0:
            return max(0.0, min(1.0, passed / total))
        return 1.0 if getattr(invariant_result, "overall_success", False) else 0.0

    @staticmethod
    def _compute_trace_score(trace_result) -> float:
        """Compute the trace score as pass/fail (1.0 or 0.0)."""
        if trace_result is None:
            return 0.0
        return 1.0 if getattr(trace_result, "overall_success", False) else 0.0

    def _is_trace_enabled(self, task_name: str) -> bool:
        """Return whether trace validation is enabled for the given task."""
        return task_name in self.trace_enabled_tasks

    def _build_normalized_weights(self, trace_enabled: bool) -> Dict[str, float]:
        """Return weights scaled to sum to one, dropping trace when disabled.

        If the remaining weights sum to zero they are returned unscaled
        rather than dividing by zero.
        """
        active_weights = dict(self.phase_weights)
        if not trace_enabled:
            active_weights.pop("trace", None)
        total_weight = sum(active_weights.values()) or 1.0
        return {phase: weight / total_weight for phase, weight in active_weights.items()}

    def _run_trace_validation(self, generation_result, task, method_name: str, model_name: str):
        """Execute the trace validation metric using the generated spec.

        Writes the generated spec and a one-line config file into a
        temporary directory, runs ``TraceValidationEvaluator`` on them and
        returns its result. Any exception is converted into a failed
        stand-in result carrying the exception text in
        ``trace_validation_error``, so this method does not raise.

        NOTE(review): ``method_name`` and ``model_name`` are accepted but
        not forwarded; the evaluator is always built with
        ``model_name="claude"`` -- confirm that this is intended.
        """
        try:
            with tempfile.TemporaryDirectory(prefix="sysmo_trace_") as tmpdir:
                tmpdir_path = Path(tmpdir)
                spec_path = tmpdir_path / f"{task.spec_module}.tla"
                spec_path.write_text(generation_result.generated_text, encoding="utf-8")

                cfg_path = tmpdir_path / f"{task.spec_module}.cfg"
                cfg_path.write_text(f"SPECIFICATION {task.spec_module}\n", encoding="utf-8")

                trace_eval = TraceValidationEvaluator(
                    spec_dir=str(tmpdir_path),
                    traces_dir="data/sys_traces",
                    with_exist_traces=20,  # Prefer existing traces to avoid regeneration when available
                    model_name="claude",
                )
                # Evaluate inside the ``with`` block: the spec and config
                # files are deleted as soon as the directory is cleaned up.
                return trace_eval.evaluate(
                    task_name=task.task_name,
                    config={},
                    spec_file_path=str(spec_path),
                    config_file_path=str(cfg_path),
                )
        except Exception as exc:  # pylint: disable=broad-exception-caught
            return self._TraceFailure(str(exc))
0 commit comments