@@ -10,9 +10,6 @@ namespace Diverge
1010/- Automating the generation of the encoding and the proofs so as to use nice
1111 syntactic sugar. -/
1212
13- syntax (name := divergentDef)
14- declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command
15-
1613open Lean Elab Term Meta Primitives Lean.Meta
1714open Utils
1815
@@ -1389,17 +1386,43 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC
13891386 else return ()
13901387 catch _ => s.restore
13911388
1392- -- The following two functions are copy-pasted from Lean.Elab.MutualDef
1393-
1389+ -- The following three functions are copy-pasted from Lean.Elab.MutualDef.lean
13941390open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues
13951391 instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef
13961392
1393+ -- Copy/pasted from Lean.Elab.Term.withHeaderSecVars (because the definition is private)
1394+ private def Term.withHeaderSecVars {α} (vars : Array Expr) (includedVars : List Name) (headers : Array DefViewElabHeader)
1395+ (k : Array Expr → TermElabM α) : TermElabM α := do
1396+ let (_, used) ← collectUsed.run {}
1397+ let (lctx, localInsts, vars) ← removeUnused vars used
1398+ withLCtx lctx localInsts <| k vars
1399+ where
1400+ collectUsed : StateRefT CollectFVars.State MetaM Unit := do
1401+ -- directly referenced in headers
1402+ headers.forM (·.type.collectFVars)
1403+ -- included by `include`
1404+ vars.forM fun var => do
1405+ let ldecl ← getFVarLocalDecl var
1406+ if includedVars.contains ldecl.userName then
1407+ modify (·.add ldecl.fvarId)
1408+ -- transitively referenced
1409+ get >>= (·.addDependencies) >>= set
1410+ -- instances (`addDependencies` unnecessary as by definition they may only reference variables
1411+ -- already included)
1412+ vars.forM fun var => do
1413+ let ldecl ← getFVarLocalDecl var
1414+ let st ← get
1415+ if ldecl.binderInfo.isInstImplicit && (← getFVars ldecl.type).all st.fvarSet.contains then
1416+ modify (·.add ldecl.fvarId)
1417+ getFVars (e : Expr) : MetaM (Array FVarId) :=
1418+ (·.2 .fvarIds) <$> e.collectFVars.run {}
1419+
13971420-- Comes from Term.isExample
13981421def isExample (views : Array DefView) : Bool :=
13991422 views.any (·.kind.isExample)
14001423
14011424open Language in
1402- def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit :=
1425+ def Term.elabMutualDef (vars : Array Expr) (includedVars : List Name) ( views : Array DefView) : TermElabM Unit :=
14031426 if isExample views then
14041427 withoutModifyingEnv do
14051428 -- save correct environment in info tree
@@ -1418,7 +1441,7 @@ where
14181441 withFunLocalDecls headers fun funFVars => do
14191442 for view in views, funFVar in funFVars do
14201443 addLocalVarInfo view.declId funFVar
1421- -- Modification 1:
1444+ -- MODIFICATION 1:
14221445 -- Add fake use site to prevent "unused variable" warning (if the
14231446 -- function is actually not recursive, Lean would print this warning).
14241447 -- Remark: we could detect this case and encode the function without
@@ -1428,7 +1451,7 @@ where
14281451 addTermInfo' view.declId funFVar
14291452 let values ←
14301453 try
1431- let values ← elabFunValues headers
1454+ let values ← elabFunValues headers vars includedVars
14321455 Term.synthesizeSyntheticMVarsNoPostponing
14331456 values.mapM (instantiateMVars ·)
14341457 catch ex =>
@@ -1438,18 +1461,23 @@ where
14381461 let letRecsToLift ← getLetRecsToLift
14391462 let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
14401463 checkLetRecsToLiftTypes funFVars letRecsToLift
1441- withUsed vars headers values letRecsToLift fun vars => do
1464+ ( if headers.all (·.kind.isTheorem) && !deprecated.oldSectionVars.get (← getOptions) then withHeaderSecVars vars includedVars headers else withUsed vars headers values letRecsToLift) fun vars => do
14421465 let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
14431466 for preDef in preDefs do
14441467 trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n {preDef.value}"
14451468 let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
14461469 let preDefs ← instantiateMVarsAtPreDecls preDefs
1470+ let preDefs ← shareCommonPreDefs preDefs
14471471 let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
14481472 for preDef in preDefs do
14491473 trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n {preDef.value}"
14501474 checkForHiddenUnivLevels allUserLevelNames preDefs
1451- addPreDefinitions preDefs -- Modification 2: we use our custom function here
1475+ addPreDefinitions preDefs -- MODIFICATION 2: we use our custom function here
14521476 processDeriving headers
1477+ for view in views, header in headers do
1478+ -- NOTE: this should be the full `ref`, and thus needs to be done after any snapshotting
1479+ -- that depends only on a part of the ref
1480+ addDeclarationRanges header.declName view.ref
14531481
14541482 processDeriving (headers : Array DefViewElabHeader) := do
14551483 for header in headers, view in views do
@@ -1460,22 +1488,61 @@ where
14601488 unless (← processDefDeriving className header.declName) do
14611489 throwError "failed to synthesize instance '{className}' for '{header.declName}'"
14621490
1491+ #check Command.elabMutualDef
1492+
1493+ -- Copy/pasted from Lean.Elab.MutualDef
14631494open Command in
1495+ open Language in
14641496def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
1465- let views ← ds.mapM fun d => do
1466- let `($mods:declModifiers divergent def $id :declId $sig:optDeclSig $val:declVal) := d
1467- | throwUnsupportedSyntax
1468- let modifiers ← elabModifiers mods
1469- let (binders, type) := expandOptDeclSig sig
1470- let deriving ? := none
1471- let headerRef := Syntax.missing -- Not sure what to put here
1472- pure { ref := d, kind := DefKind.def, headerRef, modifiers,
1473- declId := id, binders, type? := type, value := val, deriving ? }
1474- runTermElabM fun vars => Term.elabMutualDef vars views
1497+ let opts ← getOptions
1498+ withAlwaysResolvedPromises ds.size fun headerPromises => do
1499+ let snap? := (← read).snap?
1500+ let mut views := #[]
1501+ let mut defs := #[]
1502+ let mut reusedAllHeaders := true
1503+ for h : i in [0 :ds.size], headerPromise in headerPromises do
1504+ let d := ds[i]
1505+ let modifiers ← elabModifiers d[0 ]
1506+ if ds.size > 1 && modifiers.isNonrec then
1507+ throwErrorAt d "invalid use of 'nonrec' modifier in 'mutual' block"
1508+ let mut view ← mkDefView modifiers d[2 ] -- MODIFICATION: changed the index to 2
1509+ let fullHeaderRef := mkNullNode #[d[0 ], view.headerRef]
1510+ if let some snap := snap? then
1511+ view := { view with headerSnap? := some {
1512+ old? := do
1513+ -- transitioning from `Context.snap?` to `DefView.headerSnap?` invariant: if the
1514+ -- elaboration context and state are unchanged, and the syntax of this as well as all
1515+ -- previous headers is unchanged, then the elaboration result for this header (which
1516+ -- includes state from elaboration of previous headers!) should be unchanged.
1517+ guard reusedAllHeaders
1518+ let old ← snap.old?
1519+ -- blocking wait, `HeadersParsedSnapshot` (and hopefully others) should be quick
1520+ let old ← old.val.get.toTyped? DefsParsedSnapshot
1521+ let oldParsed ← old.defs[i]?
1522+ guard <| fullHeaderRef.eqWithInfoAndTraceReuse opts oldParsed.fullHeaderRef
1523+ -- no syntax guard to store, we already did the necessary checks
1524+ return ⟨.missing, oldParsed.headerProcessedSnap⟩
1525+ new := headerPromise
1526+ } }
1527+ defs := defs.push {
1528+ fullHeaderRef
1529+ headerProcessedSnap := { range? := d.getRange?, task := headerPromise.result }
1530+ }
1531+ reusedAllHeaders := reusedAllHeaders && view.headerSnap?.any (·.old?.isSome)
1532+ views := views.push view
1533+ if let some snap := snap? then
1534+ -- no non-fatal diagnostics at this point
1535+ snap.new.resolve <| .ofTyped { defs, diagnostics := .empty : DefsParsedSnapshot }
1536+ let includedVars := (← getScope).includedVars
1537+ runTermElabM fun vars => Term.elabMutualDef vars includedVars views
1538+
1539+ syntax (name := divergentDef)
1540+ declModifiers "divergent" Lean.Parser.Command.definition : command
14751541
14761542-- Special command so that we don't fall back to the built-in mutual when we produce an error.
14771543local syntax "_divergent" Parser.Command.mutual : command
1478- elab_rules : command | `(_divergent mutual $decls* end ) => Command.elabMutualDef decls
1544+ elab_rules : command
1545+ | `(_divergent mutual $decls* end ) => Command.elabMutualDef decls
14791546
14801547macro_rules
14811548 | `(mutual $decls* end ) => do
@@ -1501,6 +1568,8 @@ namespace Tests
15011568
15021569 /- Some examples of partial functions -/
15031570
1571+ -- set_option trace.Diverge true
1572+ -- set_option pp.rawOnError true
15041573 --set_option trace.Diverge.def true
15051574 --set_option trace.Diverge.def.genBody true
15061575 --set_option trace.Diverge.def.valid true
0 commit comments