88prelude
99public import Lean.Meta.Check
1010public import Lean.Meta.Tactic.AuxLemma
11+ import Lean.Util.ForEachExpr
1112
1213public section
1314
@@ -349,9 +350,67 @@ def mkValueTypeClosureAux (type : Expr) (value : Expr) : ClosureM (Expr × Expr)
349350 process
350351 pure (type, value)
351352
353+ private structure TopoSort where
354+ tempMark : FVarIdHashSet := {}
355+ doneMark : FVarIdHashSet := {}
356+ newDecls : Array LocalDecl := #[]
357+ newArgs : Array Expr := #[]
358+
359+ /--
360+ By construction, the `newLocalDecls` for fvars are in dependency order, but those for MVars may not be,
361+ and need to be interleaved appropriately. This we do a “topological insertion sort” of these.
362+ We care about efficiency for the common case of many fvars and no mvars.
363+ -/
364+ private partial def sortDecls (sortedDecls : Array LocalDecl) (sortedArgs : Array Expr)
365+ (toSortDecls : Array LocalDecl) (toSortArgs : Array Expr) : CoreM (Array LocalDecl × Array Expr):= do
366+ assert! sortedDecls.size = sortedArgs.size
367+ assert! toSortDecls.size = toSortArgs.size
368+ if toSortDecls.isEmpty then
369+ return (sortedDecls, sortedArgs)
370+ trace[Meta.Closure] "MVars to abstract, topologically sorting the abstracted variables"
371+ let mut m : Std.HashMap FVarId (LocalDecl × Expr) := {}
372+ for decl in sortedDecls, arg in sortedArgs do
373+ m := m.insert decl.fvarId (decl, arg)
374+ for decl in toSortDecls, arg in toSortArgs do
375+ m := m.insert decl.fvarId (decl, arg)
376+
377+ let rec visit (fvarId : FVarId) : StateT TopoSort CoreM Unit := do
378+ let some (decl, arg) := m.get? fvarId | return
379+ if (← get).doneMark.contains decl.fvarId then
380+ return ()
381+ trace[Meta.Closure] "Sorting decl {mkFVar decl.fvarId} : {decl.type}"
382+ if (← get).tempMark.contains decl.fvarId then
383+ throwError "cycle detected in sorting abstracted variables"
384+ assert! !decl.isLet (allowNondep := true ) -- should all be cdecls
385+ modify fun s => { s with tempMark := s.tempMark.insert decl.fvarId }
386+ let type := decl.type
387+ type.forEach' fun e => do
388+ if e.hasFVar then
389+ if e.isFVar then
390+ visit e.fvarId!
391+ return true
392+ else
393+ return false
394+ modify fun s => { s with
395+ newDecls := s.newDecls.push decl
396+ newArgs := s.newArgs.push arg
397+ doneMark := s.doneMark.insert decl.fvarId
398+ }
399+
400+ let s₀ := { newDecls := .emptyWithCapacity m.size, newArgs := .emptyWithCapacity m.size }
401+ StateT.run' (s := s₀) do
402+ for decl in sortedDecls do
403+ visit decl.fvarId
404+ for decl in toSortDecls do
405+ visit decl.fvarId
406+ let {newDecls, newArgs, .. } ← get
407+ trace[Meta.Closure] "Sorted fvars: {newDecls.map (mkFVar ·.fvarId)}"
408+ return (newDecls, newArgs)
409+
352410def mkValueTypeClosure (type : Expr) (value : Expr) (zetaDelta : Bool) : MetaM MkValueTypeClosureResult := do
353411 let ((type, value), s) ← ((mkValueTypeClosureAux type value).run { zetaDelta }).run {}
354- let newLocalDecls := s.newLocalDecls.reverse ++ s.newLocalDeclsForMVars
412+ let (newLocalDecls, newArgs) ← sortDecls s.newLocalDecls.reverse s.exprFVarArgs.reverse
413+ s.newLocalDeclsForMVars s.exprMVarArgs
355414 let newLetDecls := s.newLetDecls.reverse
356415 let type := mkForall newLocalDecls (mkForall newLetDecls type)
357416 let value := mkLambda newLocalDecls (mkLambda newLetDecls value)
@@ -360,7 +419,7 @@ def mkValueTypeClosure (type : Expr) (value : Expr) (zetaDelta : Bool) : MetaM M
360419 value := value,
361420 levelParams := s.levelParams,
362421 levelArgs := s.levelArgs,
363- exprArgs := s.exprFVarArgs.reverse ++ s.exprMVarArgs
422+ exprArgs := newArgs
364423 }
365424
366425end Closure
@@ -396,4 +455,7 @@ def mkAuxTheorem (type : Expr) (value : Expr) (zetaDelta : Bool := false) (kind?
396455 let name ← mkAuxLemma (kind? := kind?) (cache := cache) result.levelParams.toList result.type result.value
397456 return mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
398457
458+ builtin_initialize
459+ registerTraceClass `Meta.Closure
460+
399461end Lean.Meta
0 commit comments