Skip to content

Commit b17f84f

Browse files
committed
Move dischargeObligationIncremental to SMTUtils for Imperative reuse
1 parent b4e6b85 commit b17f84f

2 files changed

Lines changed: 83 additions & 54 deletions

File tree

Strata/DL/Imperative/SMTUtils.lean

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,79 @@ def addLocationInfo {P : PureExpr} [BEq P.Ident]
292292
Strata.SMT.Solver.setInfoString message.fst message.snd
293293
| .none => pure ()
294294

295+
/-- Result of encoding a proof obligation against an `AbstractSolver`.
296+
Returned by the encoder callback passed to `dischargeObligationIncremental`,
297+
consumed by the check-sat orchestration. -/
298+
structure EncodedObligation where
299+
obligationId : Strata.SMT.Term
300+
assumptionIds : List String
301+
estate : Strata.SMT.EncoderState
302+
303+
/-- Discharge a proof obligation using a live (incremental) SMT solver.
304+
The encoder callback runs against the spawned solver to emit declarations
305+
and assertions; this helper orchestrates check-sat calls and model parsing. -/
306+
def dischargeObligationIncremental {P : PureExpr} [ToFormat P.Ident] [BEq P.Ident]
307+
(encodeDecl : Strata.SMT.AbstractSolver Strata.SMT.Term Strata.SMT.TermType
308+
Strata.SMT.IncrementalSolverM →
309+
Strata.SMT.IncrementalSolverM EncodedObligation)
310+
(typedVarToSMTFn : P.Ident → P.Ty → Except Format (String × Strata.SMT.TermType))
311+
(vars : List P.TypedIdent)
312+
(smtsolver : String) (solverFlags : Array String)
313+
(satisfiabilityCheck validityCheck : Bool) :
314+
IO (Except SolverError (Result P.Ident × Result P.Ident × Strata.SMT.EncoderState)) := do
315+
let solverState ← Strata.SMT.IncrementalSolver.spawn smtsolver solverFlags
316+
let action : Strata.SMT.IncrementalSolverM
317+
(Except SolverError (Result P.Ident × Result P.Ident × Strata.SMT.EncoderState)) := do
318+
let solver := Strata.SMT.IncrementalSolver.mkIncrementalSolver
319+
let { obligationId, assumptionIds, estate } ← encodeDecl solver
320+
let varIds := assumptionIds.map fun id => Strata.SMT.Term.var ⟨id, .bool⟩
321+
let getModelForVars : Strata.SMT.IncrementalSolverM (Model P.Ident) := do
322+
if varIds.isEmpty then return []
323+
try
324+
match ← solver.getValue varIds with
325+
| .ok pairs =>
326+
match pairs with
327+
| [(.prim (.string rawOutput), _)] =>
328+
let rawModel ← parseModelDDM rawOutput
329+
match processModel typedVarToSMTFn vars rawModel estate with
330+
| .ok model => return model
331+
| .error _ => return []
332+
| _ => return []
333+
| .error _ => return []
334+
catch _ => return []
335+
let decisionToResult (decision : Except String Strata.SMT.Decision) :
336+
Strata.SMT.IncrementalSolverM (Result P.Ident) := do
337+
match decision with
338+
| .ok .sat => return .sat (← getModelForVars)
339+
| .ok .unknown =>
340+
let model ← getModelForVars
341+
return if model.isEmpty then .unknown else .unknown (some model)
342+
| .ok .unsat => return .unsat
343+
| .error msg => return .err msg
344+
let unwrap {α : Type} (label : String) (r : Except String α) : Strata.SMT.IncrementalSolverM α :=
345+
match r with
346+
| .ok a => return a
347+
| .error msg => throw (IO.userError s!"{label}: {msg}")
348+
let bothChecks := satisfiabilityCheck && validityCheck
349+
let mut satResult : Result P.Ident := .unknown
350+
let mut valResult : Result P.Ident := .unknown
351+
if bothChecks then
352+
satResult ← decisionToResult (← solver.checkSatAssuming [obligationId])
353+
let negObligation ← unwrap "mkNot" (← solver.mkNot obligationId)
354+
valResult ← decisionToResult (← solver.checkSatAssuming [negObligation])
355+
else
356+
if satisfiabilityCheck then
357+
unwrap "assert" (← solver.assert obligationId)
358+
satResult ← decisionToResult (← solver.checkSat)
359+
else if validityCheck then
360+
let negObligation ← unwrap "mkNot" (← solver.mkNot obligationId)
361+
unwrap "assert" (← solver.assert negObligation)
362+
valResult ← decisionToResult (← solver.checkSat)
363+
solver.close
364+
return .ok (satResult, valResult, estate)
365+
let (result, _) ← action.run solverState
366+
return result
367+
295368
/--
296369
Writes the proof obligation to file, discharge the obligation using SMT solver,
297370
and parse the output of the SMT solver.

Strata/Languages/Core/Verifier.lean

Lines changed: 10 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -557,8 +557,7 @@ def dischargeObligationIncremental
557557
(_label : String)
558558
(varDefinitions : List VarDefinition := [])
559559
(varDeclarations : List VarDeclaration := [])
560-
: IO (Except Imperative.SMT.SolverError (SMT.Result × SMT.Result × EncoderState)) :=
561-
open _root_.Strata.SMT.IncrementalSolver in do
560+
: IO (Except Imperative.SMT.SolverError (SMT.Result × SMT.Result × EncoderState)) := do
562561
let baseFlags := getSolverFlags options
563562
let needsIncremental := satisfiabilityCheck && validityCheck
564563
let solverSpecificFlags := match options.solver with
@@ -567,68 +566,25 @@ def dischargeObligationIncremental
567566
if needsIncremental && !baseFlags.contains "--incremental" then
568567
base ++ #["--incremental"]
569568
else base
570-
| "z3" => #["-in"] -- z3 reads from stdin with -in
569+
| "z3" => #["-in"]
571570
| _ => #[]
572571
let allFlags := solverSpecificFlags ++ baseFlags
573-
let solverState ← spawn options.solver allFlags
574-
let action : _root_.Strata.SMT.IncrementalSolverM (Except Imperative.SMT.SolverError (SMT.Result × SMT.Result × EncoderState)) := do
575-
let solver := _root_.Strata.SMT.IncrementalSolver.mkIncrementalSolver
576-
-- Solver-specific prelude (options like smt.mbqi, auto_config)
577-
let prelude : _root_.Strata.SMT.IncrementalSolverM Unit := match options.solver with
572+
let encodeDecl (solver : Strata.SMT.AbstractSolver Term TermType
573+
Strata.SMT.IncrementalSolverM) :
574+
Strata.SMT.IncrementalSolverM Imperative.SMT.EncodedObligation := do
575+
let prelude : Strata.SMT.IncrementalSolverM Unit := match options.solver with
578576
| "z3" => do
579577
solver.setOption "smt.mbqi" "false"
580578
solver.setOption "auto_config" "false"
581579
| _ => pure ()
582-
-- Encode all declarations and assertions through the AbstractSolver API.
583580
let (obligationId, ids, estate) ←
584581
_root_.Strata.SMT.Encoder.encodeDeclarationsAbstract solver ctx prelude
585582
assumptionTerms obligationTerm
586583
(varDefinitions := varDefinitions) (varDeclarations := varDeclarations)
587-
-- Variable terms for getValue
588-
let varIds := ids.map fun id => Term.var ⟨id, .bool⟩
589-
-- Helper to get model via solver.getValue and parse it.
590-
-- Called only when the decision is SAT or UNKNOWN.
591-
let getModelForVars : _root_.Strata.SMT.IncrementalSolverM (Imperative.SMT.Model Expression.Ident) := do
592-
if varIds.isEmpty then return []
593-
match ← solver.getValue varIds with
594-
| .ok pairs =>
595-
match pairs with
596-
| [(.prim (.string rawOutput), _)] =>
597-
let rawModel ← Imperative.SMT.parseModelDDM rawOutput
598-
match Imperative.SMT.processModel (typedVarToSMTFn ctx) vars rawModel estate with
599-
| .ok model => return model
600-
| .error _ => return []
601-
| _ => return []
602-
| .error _ => return []
603-
-- Issue check-sat commands through the AbstractSolver API.
604-
let decisionToResult (decision : Except String Decision)
605-
: _root_.Strata.SMT.IncrementalSolverM (Imperative.SMT.Result Expression.Ident) := do
606-
match decision with
607-
| .ok .sat => return .sat (← getModelForVars)
608-
| .ok .unknown =>
609-
let model ← getModelForVars
610-
return if model.isEmpty then .unknown else .unknown (some model)
611-
| .ok .unsat => return .unsat
612-
| .error msg => return .err msg
613-
let bothChecks := satisfiabilityCheck && validityCheck
614-
let mut satResult : Imperative.SMT.Result Expression.Ident := .unknown
615-
let mut valResult : Imperative.SMT.Result Expression.Ident := .unknown
616-
if bothChecks then
617-
satResult ← decisionToResult (← solver.checkSatAssuming [obligationId])
618-
let negObligation ← _root_.Strata.SMT.Encoder.unwrap "mkNot" (← solver.mkNot obligationId)
619-
valResult ← decisionToResult (← solver.checkSatAssuming [negObligation])
620-
else
621-
if satisfiabilityCheck then
622-
_root_.Strata.SMT.Encoder.unwrap "assert" (← solver.assert obligationId)
623-
satResult ← decisionToResult (← solver.checkSat)
624-
else if validityCheck then
625-
let negObligation ← _root_.Strata.SMT.Encoder.unwrap "mkNot" (← solver.mkNot obligationId)
626-
_root_.Strata.SMT.Encoder.unwrap "assert" (← solver.assert negObligation)
627-
valResult ← decisionToResult (← solver.checkSat)
628-
solver.close
629-
return .ok (satResult, valResult, estate)
630-
let (result, _) ← action.run solverState
631-
return result
584+
return { obligationId, assumptionIds := ids, estate }
585+
Imperative.SMT.dischargeObligationIncremental (P := Core.Expression)
586+
encodeDecl (typedVarToSMTFn ctx) vars options.solver allFlags
587+
satisfiabilityCheck validityCheck
632588

633589
end -- public section
634590
end Core.SMT

0 commit comments

Comments
 (0)