Skip to content

Commit f9dc776

Browse files
authored
feat: dedicated fix operator for well-founded recursion on Nat (#7965)
This PR lets recursive functions defined by well-founded recursion use a different `fix` function when the termination measure is of type `Nat`. This fix-point operator use structural recursion on “fuel”, initialized by the given measure, and is thus reasonable to reduce, e.g. in `by decide` proofs. Extra provisions are in place that the fixpoint operator only starts reducing when the fuel is fully known, to prevent “accidential” defeqs when the remaining fuel for the recursive calls match the initial fuel for that recursive argument. To opt-out, the idiom `termination_by (n,0)` can be used. We still use `@[irreducible]` as the default for such recursive definitions, to avoid unexpected `defeq` lemmas. Making these functions `@[semireducible]` by default showed performance regressions in lean. When the measure is of type `Nat`, the system will accept an explicit `@[semireducible]` without the usual warning. Fixes #5234. Fixes: #11181.
1 parent 1ae680c commit f9dc776

25 files changed

+381
-152
lines changed

src/Init/WF.lean

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,53 @@ end
439439

440440
end PSigma
441441

442+
namespace WellFounded
443+
444+
variable {α : Sort u}
445+
variable {motive : α → Sort v}
446+
variable (h : α → Nat)
447+
variable (F : (x : α) → ((y : α) → InvImage (· < ·) h y x → motive y) → motive x)
448+
449+
/-- Helper gadget that prevents reduction of `Nat.eager n` unless `n` evalutes to a ground term. -/
450+
def Nat.eager (n : Nat) : Nat :=
451+
if Nat.beq n n = true then n else n
452+
453+
theorem Nat.eager_eq (n : Nat) : Nat.eager n = n := ite_self n
454+
455+
/--
456+
A well-founded fixpoint operator specialized for `Nat`-valued measures. Given a measure `h`, it expects
457+
its higher order function argument `F` to invoke its argument only on values `y` that are smaller
458+
than `x` with regard to `h`.
459+
460+
In contrast to to `WellFounded.fix`, this fixpoint operator reduces on closed terms. (More precisely:
461+
when `h x` evalutes to a ground value)
462+
463+
-/
464+
def Nat.fix : (x : α) → motive x :=
465+
let rec go : ∀ (fuel : Nat) (x : α), (h x < fuel) → motive x :=
466+
Nat.rec
467+
(fun _ hfuel => (Nat.not_succ_le_zero _ hfuel).elim)
468+
(fun _ ih x hfuel => F x (fun y hy => ih y (Nat.lt_of_lt_of_le hy (Nat.le_of_lt_add_one hfuel))))
469+
fun x => go (Nat.eager (h x + 1)) x (Nat.eager_eq _ ▸ Nat.lt_add_one _)
470+
471+
protected theorem Nat.fix.go_congr (x : α) (fuel₁ fuel₂ : Nat) (h₁ : h x < fuel₁) (h₂ : h x < fuel₂) :
472+
Nat.fix.go h F fuel₁ x h₁ = Nat.fix.go h F fuel₂ x h₂ := by
473+
induction fuel₁ generalizing x fuel₂ with
474+
| zero => contradiction
475+
| succ fuel₁ ih =>
476+
cases fuel₂ with
477+
| zero => contradiction
478+
| succ fuel₂ =>
479+
exact congrArg (F x) (funext fun y => funext fun hy => ih y fuel₂ _ _ )
480+
481+
theorem Nat.fix_eq (x : α) :
482+
Nat.fix h F x = F x (fun y _ => Nat.fix h F y) := by
483+
unfold Nat.fix
484+
simp [Nat.eager_eq]
485+
exact congrArg (F x) (funext fun _ => funext fun _ => Nat.fix.go_congr ..)
486+
487+
end WellFounded
488+
442489
/--
443490
The `wfParam` gadget is used internally during the construction of recursive functions by
444491
wellfounded recursion, to keep track of the parameter for which the automatic introduction

src/Lean/Compiler/InlineAttrs.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ private def isValidMacroInline (declName : Name) : CoreM Bool := do
3737
let isRec (declName' : Name) : Bool :=
3838
isBRecOnRecursor env declName' ||
3939
declName' == ``WellFounded.fix ||
40+
declName' == ``WellFounded.Nat.fix ||
4041
declName' == declName ++ `_unary -- Auxiliary declaration created by `WF` module
4142
if Option.isSome <| info.value.find? fun e => e.isConst && isRec e.constName! then
4243
-- It contains a `brecOn` or `WellFounded.fix` application. So, it should be recursvie

src/Lean/Elab/PreDefinition/WF/Fix.lean

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,21 +237,32 @@ def solveDecreasingGoals (funNames : Array Name) (argsPacker : ArgsPacker) (decr
237237
Term.reportUnsolvedGoals remainingGoals
238238
instantiateMVars value
239239

240+
def isNatLtWF (wfRel : Expr) : MetaM (Option Expr) := do
241+
match_expr wfRel with
242+
| invImage _ β f wfRelβ =>
243+
unless (← isDefEq β (mkConst ``Nat)) do return none
244+
unless (← isDefEq wfRelβ (mkConst ``Nat.lt_wfRel)) do return none
245+
return f
246+
| _ => return none
247+
240248
def mkFix (preDef : PreDefinition) (prefixArgs : Array Expr) (argsPacker : ArgsPacker)
241249
(wfRel : Expr) (funNames : Array Name) (decrTactics : Array (Option DecreasingBy)) :
242250
TermElabM Expr := do
243251
let type ← instantiateForall preDef.type prefixArgs
244252
let (wfFix, varName) ← forallBoundedTelescope type (some 1) fun x type => do
245253
let x := x[0]!
254+
let varName ← x.fvarId!.getUserName -- See comment below.
246255
let α ← inferType x
247256
let u ← getLevel α
248257
let v ← getLevel type
249258
let motive ← mkLambdaFVars #[x] type
250-
let rel := mkProj ``WellFoundedRelation 0 wfRel
251-
let wf := mkProj ``WellFoundedRelation 1 wfRel
252-
let wf ← mkAppM `Lean.opaqueId #[wf]
253-
let varName ← x.fvarId!.getUserName -- See comment below.
254-
return (mkApp4 (mkConst ``WellFounded.fix [u, v]) α motive rel wf, varName)
259+
if let some measure ← isNatLtWF wfRel then
260+
return (mkApp3 (mkConst `WellFounded.Nat.fix [u, v]) α motive measure, varName)
261+
else
262+
let rel := mkProj ``WellFoundedRelation 0 wfRel
263+
let wf := mkProj ``WellFoundedRelation 1 wfRel
264+
let wf ← mkAppM `Lean.opaqueId #[wf]
265+
return (mkApp4 (mkConst ``WellFounded.fix [u, v]) α motive rel wf, varName)
255266
forallBoundedTelescope (← whnf (← inferType wfFix)).bindingDomain! (some 2) fun xs _ => do
256267
let x := xs[0]!
257268
-- Remark: we rename `x` here to make sure we preserve the variable name in the

src/Lean/Elab/PreDefinition/WF/Main.lean

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,19 +53,21 @@ def wfRecursion (docCtx : LocalContext × LocalInstances) (preDefs : Array PreDe
5353
-- No termination_by here, so use GuessLex to infer one
5454
guessLex preDefs unaryPreDefProcessed fixedParamPerms argsPacker
5555

56-
-- Warn about likely unwanted reducibility attributes
57-
preDefs.forM fun preDef =>
58-
preDef.modifiers.attrs.forM fun a => do
59-
if a.name = `reducible || a.name = `semireducible then
60-
logWarningAt a.stx s!"marking functions defined by well-founded recursion as `{a.name}` is not effective"
61-
6256
let preDefNonRec ← forallBoundedTelescope unaryPreDef.type fixedParamPerms.numFixed fun fixedArgs type => do
6357
let type ← whnfForall type
6458
unless type.isForall do
6559
throwError "wfRecursion: expected unary function type: {type}"
6660
let packedArgType := type.bindingDomain!
6761
elabWFRel (preDefs.map (·.declName)) unaryPreDef.declName fixedParamPerms fixedArgs argsPacker packedArgType wf fun wfRel => do
6862
trace[Elab.definition.wf] "wfRel: {wfRel}"
63+
let useNatRec := (← isNatLtWF wfRel).isSome
64+
-- Warn about likely unwanted reducibility attributes
65+
unless useNatRec do
66+
preDefs.forM fun preDef =>
67+
preDef.modifiers.attrs.forM fun a => do
68+
if a.name = `reducible || a.name = `semireducible then
69+
logWarningAt a.stx s!"marking functions defined by well-founded recursion as `{a.name}` is not effective"
70+
6971
let (value, envNew) ← withoutModifyingEnv' do
7072
addAsAxiom unaryPreDef
7173
let value ← mkFix unaryPreDefProcessed fixedArgs argsPacker wfRel (preDefs.map (·.declName)) (preDefs.map (·.termination.decreasingBy?))

src/Lean/Elab/PreDefinition/WF/Unfold.lean

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,14 @@ def rwFixEq (mvarId : MVarId) : MetaM MVarId := mvarId.withContext do
3636
-- lhs should be an application of the declNameNonrec, which unfolds to an
3737
-- application of fix in one step
3838
let some lhs' ← delta? lhs | throwError "rwFixEq: cannot delta-reduce {lhs}"
39-
let_expr WellFounded.fix _α _C _r _hwf F x := lhs'
40-
| throwTacticEx `rwFixEq mvarId "expected saturated fixed-point application in {lhs'}"
41-
let h := mkAppN (mkConst ``WellFounded.fix_eq lhs'.getAppFn.constLevels!) lhs'.getAppArgs
39+
let h ← match_expr lhs' with
40+
| WellFounded.fix _α _C _r _hwf _F _x =>
41+
pure <| mkAppN (mkConst ``WellFounded.fix_eq lhs'.getAppFn.constLevels!) lhs'.getAppArgs
42+
| WellFounded.Nat.fix _α _C _motive _F _x =>
43+
pure <| mkAppN (mkConst ``WellFounded.Nat.fix_eq lhs'.getAppFn.constLevels!) lhs'.getAppArgs
44+
| _ => throwTacticEx `rwFixEq mvarId m!"expected saturated fixed-point application in {lhs'}"
45+
let F := lhs'.appFn!.appArg!
46+
let x := lhs'.appArg!
4247

4348
-- We used to just rewrite with `fix_eq` and continue with whatever RHS that produces, but that
4449
-- would include more copies of `fix` resulting in large and confusing terms.

src/Lean/Meta/Tactic/FunInd.lean

Lines changed: 68 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -927,76 +927,84 @@ where doRealize (inductName : Name) := do
927927
-- to make sure that `target` indeed the last parameter
928928
let e := info.value
929929
let e ← lambdaTelescope e fun params body => do
930-
if body.isAppOfArity ``WellFounded.fix 5 then
930+
if body.isAppOfArity ``WellFounded.fix 5 || body.isAppOfArity ``WellFounded.Nat.fix 4 then
931931
forallBoundedTelescope (← inferType body) (some 1) fun xs _ => do
932932
unless xs.size = 1 do
933933
throwError "functional induction: Failed to eta-expand{indentExpr e}"
934934
mkLambdaFVars (params ++ xs) (mkAppN body xs)
935935
else
936936
pure e
937937
let (e', paramMask) ← lambdaTelescope e fun params funBody => MatcherApp.withUserNames params varNames do
938-
match_expr funBody with
939-
| [email protected] α _motive rel wf body target =>
940-
unless params.back! == target do
941-
throwError "functional induction: expected the target as last parameter{indentExpr e}"
942-
let fixedParamPerms := params.pop
943-
let motiveType ←
944-
if unfolding then
945-
withLocalDeclD `r (← instantiateForall info.type params) fun r =>
946-
mkForallFVars #[target, r] (.sort 0)
938+
unless funBody.isApp && funBody.appFn!.isApp do
939+
throwError "functional induction: unexpected body {funBody}"
940+
let body := funBody.appFn!.appArg!
941+
let target := funBody.appArg!
942+
unless params.back! == target do
943+
throwError "functional induction: expected the target as last parameter{indentExpr e}"
944+
let fixedParamPerms := params.pop
945+
let motiveType ←
946+
if unfolding then
947+
withLocalDeclD `r (← instantiateForall info.type params) fun r =>
948+
mkForallFVars #[target, r] (.sort 0)
949+
else
950+
mkForallFVars #[target] (.sort 0)
951+
withLocalDeclD `motive motiveType fun motive => do
952+
let fn := mkAppN (← mkConstWithLevelParams name) fixedParamPerms
953+
let isRecCall : Expr → Option Expr := fun e =>
954+
e.withApp fun f xs =>
955+
if f.isFVarOf motive.fvarId! && xs.size > 0 then
956+
mkApp fn xs[0]!
947957
else
948-
mkForallFVars #[target] (.sort 0)
949-
withLocalDeclD `motive motiveType fun motive => do
950-
let fn := mkAppN (← mkConstWithLevelParams name) fixedParamPerms
951-
let isRecCall : Expr → Option Expr := fun e =>
952-
e.withApp fun f xs =>
953-
if f.isFVarOf motive.fvarId! && xs.size > 0 then
954-
mkApp fn xs[0]!
955-
else
956-
none
958+
none
957959

958-
let motiveArg ←
959-
if unfolding then
960-
let motiveArg := mkApp2 motive target (mkAppN (← mkConstWithLevelParams name) params)
961-
mkLambdaFVars #[target] motiveArg
960+
let motiveArg ←
961+
if unfolding then
962+
let motiveArg := mkApp2 motive target (mkAppN (← mkConstWithLevelParams name) params)
963+
mkLambdaFVars #[target] motiveArg
964+
else
965+
pure motive
966+
967+
let e' ← match_expr funBody with
968+
| [email protected] α _motive rel wf _body _target =>
969+
let e' := .const ``WellFounded.fix [fix.constLevels![0]!, levelZero]
970+
pure <| mkApp4 e' α motiveArg rel wf
971+
| [email protected] α _motive measure _body _target =>
972+
let e' := .const `WellFounded.Nat.fix [fix.constLevels![0]!, levelZero]
973+
pure <| mkApp3 e' α motiveArg measure
974+
| _ =>
975+
if funBody.isAppOf ``WellFounded.fix || funBody.isAppOf `WellFounded.Nat.Fix then
976+
throwError "Function {name} defined via WellFounded.fix with unexpected arity {funBody.getAppNumArgs}:{indentExpr funBody}"
962977
else
963-
pure motive
964-
let e' := .const ``WellFounded.fix [fix.constLevels![0]!, levelZero]
965-
let e' := mkApp4 e' α motiveArg rel wf
966-
check e'
967-
let (body', mvars) ← M2.run do
968-
forallTelescope (← inferType e').bindingDomain! fun xs goal => do
969-
if xs.size ≠ 2 then
970-
throwError "expected recursor argument to take 2 parameters, got {xs}" else
971-
let targets : Array Expr := xs[*...1]
972-
let genIH := xs[1]!
973-
let extraParams := xs[2...*]
974-
-- open body with the same arg
975-
let body ← instantiateLambda body targets
976-
lambdaTelescope1 body fun oldIH body => do
977-
let body ← instantiateLambda body extraParams
978-
let body' ← withRewrittenMotiveArg goal (rwFun #[name]) fun goal => do
979-
buildInductionBody #[oldIH, genIH.fvarId!] #[] goal oldIH genIH.fvarId! isRecCall body
980-
if body'.containsFVar oldIH then
981-
throwError m!"Did not fully eliminate `{mkFVar oldIH}` from induction principle body:{indentExpr body}"
982-
mkLambdaFVars (targets.push genIH) (← mkLambdaFVars extraParams body')
983-
let e' := mkApp2 e' body' target
984-
let e' ← mkLambdaFVars #[target] e'
985-
let e' ← abstractIndependentMVars mvars (← motive.fvarId!.getDecl).index e'
986-
let e' ← mkLambdaFVars #[motive] e'
987-
988-
-- We used to pass (usedOnly := false) below in the hope that the types of the
989-
-- induction principle match the type of the function better.
990-
-- But this leads to avoidable parameters that make functional induction strictly less
991-
-- useful (e.g. when the unused parameter mentions bound variables in the users' goal)
992-
let (paramMask, e') ← mkLambdaFVarsMasked fixedParamPerms e'
993-
let e' ← instantiateMVars e'
994-
return (e', paramMask)
995-
| _ =>
996-
if funBody.isAppOf ``WellFounded.fix then
997-
throwError "Function `{name}` defined via `{.ofConstName ``WellFounded.fix}` with unexpected arity {funBody.getAppNumArgs}:{indentExpr funBody}"
998-
else
999-
throwError "Function `{name}` not defined via `{.ofConstName ``WellFounded.fix}`:{indentExpr funBody}"
978+
throwError "Function {name} not defined via WellFounded.fix:{indentExpr funBody}"
979+
check e'
980+
let (body', mvars) ← M2.run do
981+
forallTelescope (← inferType e').bindingDomain! fun xs goal => do
982+
if xs.size ≠ 2 then
983+
throwError "expected recursor argument to take 2 parameters, got {xs}" else
984+
let targets : Array Expr := xs[*...1]
985+
let genIH := xs[1]!
986+
let extraParams := xs[2...*]
987+
-- open body with the same arg
988+
let body ← instantiateLambda body targets
989+
lambdaTelescope1 body fun oldIH body => do
990+
let body ← instantiateLambda body extraParams
991+
let body' ← withRewrittenMotiveArg goal (rwFun #[name]) fun goal => do
992+
buildInductionBody #[oldIH, genIH.fvarId!] #[] goal oldIH genIH.fvarId! isRecCall body
993+
if body'.containsFVar oldIH then
994+
throwError m!"Did not fully eliminate `{mkFVar oldIH}` from induction principle body:{indentExpr body}"
995+
mkLambdaFVars (targets.push genIH) (← mkLambdaFVars extraParams body')
996+
let e' := mkApp2 e' body' target
997+
let e' ← mkLambdaFVars #[target] e'
998+
let e' ← abstractIndependentMVars mvars (← motive.fvarId!.getDecl).index e'
999+
let e' ← mkLambdaFVars #[motive] e'
1000+
1001+
-- We used to pass (usedOnly := false) below in the hope that the types of the
1002+
-- induction principle match the type of the function better.
1003+
-- But this leads to avoidable parameters that make functional induction strictly less
1004+
-- useful (e.g. when the unused parameter mentions bound variables in the users' goal)
1005+
let (paramMask, e') ← mkLambdaFVarsMasked fixedParamPerms e'
1006+
let e' ← instantiateMVars e'
1007+
return (e', paramMask)
10001008

10011009
unless (← isTypeCorrect e') do
10021010
logError m!"failed to derive a type-correct induction principle:{indentExpr e'}"

src/Std/Sat/AIG/Basic.lean

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,7 @@ where
485485
let lval := go lhs.gate decls assign (by omega) h2
486486
let rval := go rhs.gate decls assign (by omega) h2
487487
xor lval lhs.invert && xor rval rhs.invert
488+
termination_by (x, 0) -- Don't allow reduction, we have large concrete gate entries
488489

489490
/--
490491
Denotation of an `AIG` at a specific `Entrypoint`.

stage0/src/stdlib_flags.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#include "util/options.h"
22

3+
// please update stage0
4+
35
namespace lean {
46
options get_default_options() {
57
options opts;

tests/lean/run/1026.lean

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ info: foo.eq_def (n : Nat) :
1616
if n = 0 then 0
1717
else
1818
have x := n - 1;
19-
have this := foo._proof_4;
19+
have this := foo._proof_3;
2020
foo x
2121
-/
2222
#guard_msgs in
@@ -28,7 +28,7 @@ info: foo.eq_unfold :
2828
if n = 0 then 0
2929
else
3030
have x := n - 1;
31-
have this := foo._proof_4;
31+
have this := foo._proof_3;
3232
foo x
3333
-/
3434
#guard_msgs in

tests/lean/run/4928.lean

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ end
4646
/--
4747
error: Failed: `fail` tactic was invoked
4848
x : List Nat
49-
(invImage (fun x => PSum.casesOn x (fun x => x.length) fun x => x.length) sizeOfWFRel).1 (PSum.inr x.tail)
49+
InvImage (fun x1 x2 => x1 < x2) (fun x => PSum.casesOn x (fun x => x.length) fun x => x.length) (PSum.inr x.tail)
5050
(PSum.inl x)
5151
-/
5252
#guard_msgs in

0 commit comments

Comments
 (0)