Skip to content

Commit a6d50a6

Browse files
nomeataeric-wieser
andauthored
fix: Meta.Closure: topologically sort abstracted vars (#10926)
This PR topologically sorts abstracted vars in `Meta.Closure.mkValueTypeClosure` if MVars are being abstracted. Fixes #10705 --------- Co-authored-by: Eric Wieser <[email protected]>
1 parent dec0076 commit a6d50a6

File tree

2 files changed

+85
-2
lines changed

2 files changed

+85
-2
lines changed

src/Lean/Meta/Closure.lean

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ module
88
prelude
99
public import Lean.Meta.Check
1010
public import Lean.Meta.Tactic.AuxLemma
11+
import Lean.Util.ForEachExpr
1112

1213
public section
1314

@@ -349,9 +350,67 @@ def mkValueTypeClosureAux (type : Expr) (value : Expr) : ClosureM (Expr × Expr)
349350
process
350351
pure (type, value)
351352

353+
private structure TopoSort where
354+
tempMark : FVarIdHashSet := {}
355+
doneMark : FVarIdHashSet := {}
356+
newDecls : Array LocalDecl := #[]
357+
newArgs : Array Expr := #[]
358+
359+
/--
360+
By construction, the `newLocalDecls` for fvars are in dependency order, but those for MVars may not be,
361+
and need to be interleaved appropriately. This we do a “topological insertion sort” of these.
362+
We care about efficiency for the common case of many fvars and no mvars.
363+
-/
364+
private partial def sortDecls (sortedDecls : Array LocalDecl) (sortedArgs : Array Expr)
365+
(toSortDecls : Array LocalDecl) (toSortArgs : Array Expr) : CoreM (Array LocalDecl × Array Expr):= do
366+
assert! sortedDecls.size = sortedArgs.size
367+
assert! toSortDecls.size = toSortArgs.size
368+
if toSortDecls.isEmpty then
369+
return (sortedDecls, sortedArgs)
370+
trace[Meta.Closure] "MVars to abstract, topologically sorting the abstracted variables"
371+
let mut m : Std.HashMap FVarId (LocalDecl × Expr) := {}
372+
for decl in sortedDecls, arg in sortedArgs do
373+
m := m.insert decl.fvarId (decl, arg)
374+
for decl in toSortDecls, arg in toSortArgs do
375+
m := m.insert decl.fvarId (decl, arg)
376+
377+
let rec visit (fvarId : FVarId) : StateT TopoSort CoreM Unit := do
378+
let some (decl, arg) := m.get? fvarId | return
379+
if (← get).doneMark.contains decl.fvarId then
380+
return ()
381+
trace[Meta.Closure] "Sorting decl {mkFVar decl.fvarId} : {decl.type}"
382+
if (← get).tempMark.contains decl.fvarId then
383+
throwError "cycle detected in sorting abstracted variables"
384+
assert! !decl.isLet (allowNondep := true) -- should all be cdecls
385+
modify fun s => { s with tempMark := s.tempMark.insert decl.fvarId }
386+
let type := decl.type
387+
type.forEach' fun e => do
388+
if e.hasFVar then
389+
if e.isFVar then
390+
visit e.fvarId!
391+
return true
392+
else
393+
return false
394+
modify fun s => { s with
395+
newDecls := s.newDecls.push decl
396+
newArgs := s.newArgs.push arg
397+
doneMark := s.doneMark.insert decl.fvarId
398+
}
399+
400+
let s₀ := { newDecls := .emptyWithCapacity m.size, newArgs := .emptyWithCapacity m.size }
401+
StateT.run' (s := s₀) do
402+
for decl in sortedDecls do
403+
visit decl.fvarId
404+
for decl in toSortDecls do
405+
visit decl.fvarId
406+
let {newDecls, newArgs, .. } ← get
407+
trace[Meta.Closure] "Sorted fvars: {newDecls.map (mkFVar ·.fvarId)}"
408+
return (newDecls, newArgs)
409+
352410
def mkValueTypeClosure (type : Expr) (value : Expr) (zetaDelta : Bool) : MetaM MkValueTypeClosureResult := do
353411
let ((type, value), s) ← ((mkValueTypeClosureAux type value).run { zetaDelta }).run {}
354-
let newLocalDecls := s.newLocalDecls.reverse ++ s.newLocalDeclsForMVars
412+
let (newLocalDecls, newArgs) ← sortDecls s.newLocalDecls.reverse s.exprFVarArgs.reverse
413+
s.newLocalDeclsForMVars s.exprMVarArgs
355414
let newLetDecls := s.newLetDecls.reverse
356415
let type := mkForall newLocalDecls (mkForall newLetDecls type)
357416
let value := mkLambda newLocalDecls (mkLambda newLetDecls value)
@@ -360,7 +419,7 @@ def mkValueTypeClosure (type : Expr) (value : Expr) (zetaDelta : Bool) : MetaM M
360419
value := value,
361420
levelParams := s.levelParams,
362421
levelArgs := s.levelArgs,
363-
exprArgs := s.exprFVarArgs.reverse ++ s.exprMVarArgs
422+
exprArgs := newArgs
364423
}
365424

366425
end Closure
@@ -396,4 +455,7 @@ def mkAuxTheorem (type : Expr) (value : Expr) (zetaDelta : Bool := false) (kind?
396455
let name ← mkAuxLemma (kind? := kind?) (cache := cache) result.levelParams.toList result.type result.value
397456
return mkAppN (mkConst name result.levelArgs.toList) result.exprArgs
398457

458+
builtin_initialize
459+
registerTraceClass `Meta.Closure
460+
399461
end Lean.Meta

tests/lean/run/issue10705.lean

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import Lean
2+
3+
open Lean Meta
4+
5+
-- set_option trace.Meta.Closure true
6+
7+
/--
8+
info: before: h.2
9+
---
10+
info: after: _proof_2 ?m.1 h
11+
-/
12+
#guard_msgs(pass trace, all) in
13+
run_meta do
14+
have l := ← Lean.Meta.mkFreshExprMVar (mkConst ``True) (kind := .syntheticOpaque)
15+
let ty := mkAnd (mkConst ``True) (.letE `foo (mkConst ``True) l (mkConst ``True) false)
16+
withLocalDecl `h default ty fun x => do
17+
let e := mkProj ``And 1 x
18+
Lean.logInfo m!"before: {e}"
19+
-- works fine without this line
20+
let e ← Lean.Meta.mkAuxTheorem (mkConst ``True) e (zetaDelta := true) -- or false, not really relevant
21+
Lean.logInfo m!"after: {e}"

0 commit comments

Comments
 (0)