Skip to content

Commit 1a3901a

Browse files
committed
Prevent accidential reduction of open fuel
1 parent d08c0de commit 1a3901a

File tree

2 files changed

+91
-3
lines changed

2 files changed

+91
-3
lines changed

src/Init/WF.lean

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -446,12 +446,27 @@ variable {motive : α → Sort v}
446446
variable (h : α → Nat)
447447
variable (F : (x : α) → ((y : α) → InvImage (· < ·) h y x → motive y) → motive x)
448448

449+
/-- Helper gadget that prevents reduction of `Nat.eager n` unless `n` evalutes to a ground term. -/
450+
def Nat.eager (n : Nat) : Nat :=
451+
if Nat.beq n n = true then n else n
452+
453+
theorem Nat.eager_eq (n : Nat) : Nat.eager n = n := ite_self n
454+
455+
/--
456+
A well-founded fixpoint operator specialized for `Nat`-valued measures. Given a measure `h`, it expects
457+
its higher order function argument `F` to invoke its argument only on values `y` that are smaller
458+
than `x` with regard to `h`.
459+
460+
In contrast to to `WellFounded.fix`, this fixpoint operator reduces on closed terms. (More precisely:
461+
when `h x` evalutes to a ground value)
462+
463+
-/
449464
def Nat.fix : (x : α) → motive x :=
450465
let rec go : ∀ (fuel : Nat) (x : α), (h x < fuel) → motive x :=
451466
Nat.rec
452467
(fun _ hfuel => (Nat.not_succ_le_zero _ hfuel).elim)
453468
(fun _ ih x hfuel => F x (fun y hy => ih y (Nat.lt_of_lt_of_le hy (Nat.le_of_lt_add_one hfuel))))
454-
fun x => go (h x + 1) x (Nat.lt_add_one _)
469+
fun x => go (Nat.eager (h x + 1)) x (Nat.eager_eq _ ▸ Nat.lt_add_one _)
455470

456471
protected theorem Nat.fix.go_congr (x : α) (fuel₁ fuel₂ : Nat) (h₁ : h x < fuel₁) (h₂ : h x < fuel₂) :
457472
Nat.fix.go h F fuel₁ x h₁ = Nat.fix.go h F fuel₂ x h₂ := by
@@ -464,8 +479,10 @@ protected theorem Nat.fix.go_congr (x : α) (fuel₁ fuel₂ : Nat) (h₁ : h x
464479
exact congrArg (F x) (funext fun y => funext fun hy => ih y fuel₂ _ _ )
465480

466481
theorem Nat.fix_eq (x : α) :
467-
Nat.fix h F x = F x (fun y _ => Nat.fix h F y) :=
468-
congrArg (F x) (funext fun _ => funext fun _ => Nat.fix.go_congr ..)
482+
Nat.fix h F x = F x (fun y _ => Nat.fix h F y) := by
483+
unfold Nat.fix
484+
simp [Nat.eager_eq]
485+
exact congrArg (F x) (funext fun _ => funext fun _ => Nat.fix.go_congr ..)
469486

470487
end WellFounded
471488

tests/lean/run/wfrec-nat.lean

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/-!
2+
Tests around the special case of well-founded recursion on Nat.
3+
-/
4+
5+
namespace T1
6+
7+
def foo : List α → Nat
8+
| [] => 0
9+
| _::xs => 1 + (foo xs)
10+
termination_by xs => xs.length
11+
12+
-- Closed terms should evaluate
13+
14+
example : foo ([] : List Unit) = 0 := rfl
15+
example : foo ([] : List Unit) = 0 := by decide
16+
example : foo ([] : List Unit) = 0 := by decide +kernel
17+
example : foo [1,2,3,4,5] = 5 := rfl
18+
example : foo [1,2,3,4,5] = 5 := by decide
19+
example : foo [1,2,3,4,5] = 5 := by decide +kernel
20+
21+
-- Open terms should not (these wouldn't even without the provisions with `WellFounded.Nat.eager`,
22+
-- the fuel does not line up)
23+
24+
example : foo (x::xs) = 1 + foo xs := by (fail_if_success rfl); simp [foo]
25+
example : foo (x::y::z::xs) = 1+ (1+(1+ foo xs)) := by (fail_if_success rfl); simp [foo]
26+
27+
end T1
28+
29+
-- Variant where the fuel does not line up
30+
31+
namespace T2
32+
def foo : List α → Nat
33+
| [] => 0
34+
| _::xs => 1 + (foo xs)
35+
termination_by xs => 2 * xs.length
36+
37+
example : foo ([] : List Unit) = 0 := rfl
38+
example : foo ([] : List Unit) = 0 := by decide
39+
example : foo ([] : List Unit) = 0 := by decide +kernel
40+
example : foo [1,2,3,4,5] = 5 := rfl
41+
example : foo [1,2,3,4,5] = 5 := by decide
42+
example : foo [1,2,3,4,5] = 5 := by decide +kernel
43+
44+
-- Open terms should not (these wouldn't even without the provisions, the fuel does not line up)
45+
46+
example : foo (x::xs) = 1 + foo xs := by (fail_if_success rfl); simp [foo]
47+
example : foo (x::y::z::xs) = 1+ (1 + ( 1+ foo xs)) := by (fail_if_success rfl); simp [foo]
48+
49+
end T2
50+
51+
-- Idiom to switch to `WellFounded.fix`
52+
53+
namespace T3
54+
def foo : List α → Nat
55+
| [] => 0
56+
| _::xs => 1 + (foo xs)
57+
termination_by xs => (xs.length, 0)
58+
59+
example : foo ([] : List Unit) = 0 := by (fail_if_success rfl); simp [foo]
60+
example : foo ([] : List Unit) = 0 := by (fail_if_success decide); simp [foo]
61+
example : foo ([] : List Unit) = 0 := by (fail_if_success decide +kernel); simp [foo]
62+
example : foo [1,2,3,4,5] = 5 := by (fail_if_success rfl); simp [foo]
63+
example : foo [1,2,3,4,5] = 5 := by (fail_if_success decide); simp [foo]
64+
example : foo [1,2,3,4,5] = 5 := by (fail_if_success decide +kernel); simp [foo]
65+
66+
-- Open terms should not (these wouldn't even without the provisions, the fuel does not line up)
67+
68+
example : foo (x::xs) = 1 + foo xs := by (fail_if_success rfl); simp [foo]
69+
example : foo (x::y::z::xs) = 1+ (1 + ( 1+ foo xs)) := by (fail_if_success rfl); simp [foo]
70+
71+
end T3

0 commit comments

Comments
 (0)