Skip to content

Commit 3bf95e9

Browse files
kim-emeric-wieser
andauthored
feat: add List/Array/Vector.ofFnM (#8389)
This PR adds the `List/Array/Vector.ofFnM`, the monadic analogues of `ofFn`, along with basic theory. At the same time we pave some potholes in nearby API. --------- Co-authored-by: Eric Wieser <[email protected]>
1 parent bc21b57 commit 3bf95e9

File tree

16 files changed

+571
-41
lines changed

16 files changed

+571
-41
lines changed

src/Init/Data/Array/Basic.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ theorem mem_def {a : α} {as : Array α} : a ∈ as ↔ a ∈ as.toList :=
112112
rw [Array.mem_def, ← getElem_toList]
113113
apply List.getElem_mem
114114

115+
@[simp] theorem emptyWithCapacity_eq {α n} : @emptyWithCapacity α n = #[] := rfl
116+
117+
@[simp] theorem mkEmpty_eq {α n} : @mkEmpty α n = #[] := rfl
118+
115119
end Array
116120

117121
namespace List
@@ -334,6 +338,8 @@ def ofFn {n} (f : Fin n → α) : Array α := go 0 (emptyWithCapacity n) where
334338
if h : i < n then go (i+1) (acc.push (f ⟨i, h⟩)) else acc
335339
decreasing_by simp_wf; decreasing_trivial_pre_omega
336340
341+
-- See also `Array.ofFnM` defined in `Init.Data.Array.OfFn`.
342+
337343
/--
338344
Constructs an array that contains all the numbers from `0` to `n`, exclusive.
339345

src/Init/Data/Array/Lemmas.lean

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,6 @@ theorem toArray_eq : List.toArray as = xs ↔ as = xs.toList := by
6161

6262
@[grind] theorem size_empty : (#[] : Array α).size = 0 := rfl
6363

64-
@[simp] theorem emptyWithCapacity_eq {α n} : @emptyWithCapacity α n = #[] := rfl
65-
66-
@[deprecated emptyWithCapacity_eq (since := "2025-03-12")]
67-
theorem mkEmpty_eq {α n} : @mkEmpty α n = #[] := rfl
68-
6964
/-! ### size -/
7065

7166
@[grind →] theorem eq_empty_of_size_eq_zero (h : xs.size = 0) : xs = #[] := by
@@ -247,6 +242,12 @@ theorem back?_pop {xs : Array α} :
247242

248243
/-! ### push -/
249244

245+
@[simp] theorem push_empty : #[].push x = #[x] := rfl
246+
247+
@[simp] theorem toList_push {xs : Array α} {x : α} : (xs.push x).toList = xs.toList ++ [x] := by
248+
rcases xs with ⟨xs⟩
249+
simp
250+
250251
@[simp] theorem push_ne_empty {a : α} {xs : Array α} : xs.push a ≠ #[] := by
251252
cases xs
252253
simp
@@ -3266,6 +3267,22 @@ rather than `(arr.push a).size` as the argument.
32663267
(xs.push a).foldrM f init start = f a init >>= xs.foldrM f := by
32673268
simp [← foldrM_push, h]
32683269

3270+
@[simp, grind] theorem _root_.List.foldrM_push_eq_append [Monad m] [LawfulMonad m] {l : List α} {f : α → m β} {xs : Array β} :
3271+
l.foldrM (fun x xs => xs.push <$> f x) xs = do return xs ++ (← l.reverse.mapM f).toArray := by
3272+
induction l with
3273+
| nil => simp
3274+
| cons a l ih =>
3275+
simp [ih]
3276+
congr 1
3277+
funext l'
3278+
congr 1
3279+
funext x
3280+
simp
3281+
3282+
@[simp, grind] theorem _root_.List.foldlM_push_eq_append [Monad m] [LawfulMonad m] {l : List α} {f : α → m β} {xs : Array β} :
3283+
l.foldlM (fun xs x => xs.push <$> f x) xs = do return xs ++ (← l.mapM f).toArray := by
3284+
induction l generalizing xs <;> simp [*]
3285+
32693286
/-! ### foldl / foldr -/
32703287

32713288
@[grind] theorem foldl_empty {f : β → α → β} {init : β} : (#[].foldl f init) = init := rfl
@@ -3362,6 +3379,32 @@ rather than `(arr.push a).size` as the argument.
33623379
rcases as with ⟨as⟩
33633380
simp
33643381

3382+
@[simp, grind] theorem _root_.List.foldr_push_eq_append {l : List α} {f : α → β} {xs : Array β} :
3383+
l.foldr (fun x xs => xs.push (f x)) xs = xs ++ (l.reverse.map f).toArray := by
3384+
induction l <;> simp [*]
3385+
3386+
/-- Variant of `List.foldr_push_eq_append` specialized to `f = id`. -/
3387+
@[simp, grind] theorem _root_.List.foldr_push_eq_append' {l : List α} {xs : Array α} :
3388+
l.foldr (fun x xs => xs.push x) xs = xs ++ l.reverse.toArray := by
3389+
induction l <;> simp [*]
3390+
3391+
@[simp, grind] theorem _root_.List.foldl_push_eq_append {l : List α} {f : α → β} {xs : Array β} :
3392+
l.foldl (fun xs x => xs.push (f x)) xs = xs ++ (l.map f).toArray := by
3393+
induction l generalizing xs <;> simp [*]
3394+
3395+
/-- Variant of `List.foldl_push_eq_append` specialized to `f = id`. -/
3396+
@[simp, grind] theorem _root_.List.foldl_push_eq_append' {l : List α} {xs : Array α} :
3397+
l.foldl (fun xs x => xs.push x) xs = xs ++ l.toArray := by
3398+
simpa using List.foldl_push_eq_append (f := id)
3399+
3400+
@[deprecated _root_.List.foldl_push_eq_append' (since := "2025-05-18")]
3401+
theorem _root_.List.foldl_push {l : List α} {as : Array α} : l.foldl Array.push as = as ++ l.toArray := by
3402+
induction l generalizing as <;> simp [*]
3403+
3404+
@[deprecated _root_.List.foldr_push_eq_append' (since := "2025-05-18")]
3405+
theorem _root_.List.foldr_push {l : List α} {as : Array α} : l.foldr (fun a bs => push bs a) as = as ++ l.reverse.toArray := by
3406+
rw [List.foldr_eq_foldl_reverse, List.foldl_push_eq_append']
3407+
33653408
@[simp, grind] theorem foldr_append_eq_append {xs : Array α} {f : α → Array β} {ys : Array β} :
33663409
xs.foldr (f · ++ ·) ys = (xs.map f).flatten ++ ys := by
33673410
rcases xs with ⟨xs⟩

src/Init/Data/Array/Monadic.lean

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ open Nat
2525

2626
/-! ## Monadic operations -/
2727

28+
@[simp] theorem map_toList_inj [Monad m] [LawfulMonad m]
29+
{xs : m (Array α)} {ys : m (Array α)} :
30+
toList <$> xs = toList <$> ys ↔ xs = ys :=
31+
_root_.map_inj_right (by simp)
32+
2833
/-! ### mapM -/
2934

3035
@[simp] theorem mapM_pure [Monad m] [LawfulMonad m] {xs : Array α} {f : α → β} :
@@ -34,6 +39,11 @@ open Nat
3439
@[simp] theorem mapM_id {xs : Array α} {f : α → Id β} : xs.mapM f = xs.map f :=
3540
mapM_pure
3641

42+
@[simp] theorem mapM_map [Monad m] [LawfulMonad m] {f : α → β} {g : β → m γ} {xs : Array α} :
43+
(xs.map f).mapM g = xs.mapM (g ∘ f) := by
44+
rcases xs with ⟨xs⟩
45+
simp
46+
3747
@[simp] theorem mapM_append [Monad m] [LawfulMonad m] {f : α → m β} {xs ys : Array α} :
3848
(xs ++ ys).mapM f = (return (← xs.mapM f) ++ (← ys.mapM f)) := by
3949
rcases xs with ⟨xs⟩

src/Init/Data/Array/OfFn.lean

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ module
88
prelude
99
import all Init.Data.Array.Basic
1010
import Init.Data.Array.Lemmas
11+
import Init.Data.Array.Monadic
1112
import Init.Data.List.OfFn
13+
import Init.Data.List.FinRange
1214

1315
/-!
1416
# Theorems about `Array.ofFn`
@@ -19,6 +21,8 @@ set_option linter.indexVariables true -- Enforce naming conventions for index va
1921

2022
namespace Array
2123

24+
/-! ### ofFn -/
25+
2226
@[simp] theorem ofFn_zero {f : Fin 0 → α} : ofFn f = #[] := by
2327
simp [ofFn, ofFn.go]
2428

@@ -32,12 +36,23 @@ theorem ofFn_succ {f : Fin (n+1) → α} :
3236
intro h₃
3337
simp only [show i = n by omega]
3438

39+
theorem ofFn_add {n m} {f : Fin (n + m) → α} :
40+
ofFn f = (ofFn (fun i => f (i.castLE (Nat.le_add_right n m)))) ++ (ofFn (fun i => f (i.natAdd n))) := by
41+
induction m with
42+
| zero => simp
43+
| succ m ih => simp [ofFn_succ, ih]
44+
3545
@[simp] theorem _root_.List.toArray_ofFn {f : Fin n → α} : (List.ofFn f).toArray = Array.ofFn f := by
3646
ext <;> simp
3747

3848
@[simp] theorem toList_ofFn {f : Fin n → α} : (Array.ofFn f).toList = List.ofFn f := by
3949
apply List.ext_getElem <;> simp
4050

51+
theorem ofFn_succ' {f : Fin (n+1) → α} :
52+
ofFn f = #[f 0] ++ ofFn (fun i => f i.succ) := by
53+
apply Array.toList_inj.mp
54+
simp [List.ofFn_succ]
55+
4156
@[simp]
4257
theorem ofFn_eq_empty_iff {f : Fin n → α} : ofFn f = #[] ↔ n = 0 := by
4358
rw [← Array.toList_inj]
@@ -52,4 +67,71 @@ theorem mem_ofFn {n} {f : Fin n → α} {a : α} : a ∈ ofFn f ↔ ∃ i, f i =
5267
· rintro ⟨i, rfl⟩
5368
apply mem_of_getElem (i := i) <;> simp
5469

70+
/-! ### ofFnM -/
71+
72+
/-- Construct (in a monadic context) an array by applying a monadic function to each index. -/
73+
def ofFnM {n} [Monad m] (f : Fin n → m α) : m (Array α) :=
74+
Fin.foldlM n (fun xs i => xs.push <$> f i) (Array.emptyWithCapacity n)
75+
76+
@[simp]
77+
theorem ofFnM_zero [Monad m] {f : Fin 0 → m α} : ofFnM f = pure #[] := by
78+
simp [ofFnM]
79+
80+
theorem ofFnM_succ' {n} [Monad m] [LawfulMonad m] {f : Fin (n + 1) → m α} :
81+
ofFnM f = (do
82+
let a ← f 0
83+
let as ← ofFnM fun i => f i.succ
84+
pure (#[a] ++ as)) := by
85+
simp [ofFnM, Fin.foldlM_eq_finRange_foldlM, List.foldlM_push_eq_append, List.finRange_succ, Function.comp_def]
86+
87+
theorem ofFnM_succ {n} [Monad m] [LawfulMonad m] {f : Fin (n + 1) → m α} :
88+
ofFnM f = (do
89+
let as ← ofFnM fun i => f i.castSucc
90+
let a ← f (Fin.last n)
91+
pure (as.push a)) := by
92+
simp [ofFnM, Fin.foldlM_succ_last]
93+
94+
theorem ofFnM_add {n m} [Monad m] [LawfulMonad m] {f : Fin (n + k) → m α} :
95+
ofFnM f = (do
96+
let as ← ofFnM fun i : Fin n => f (i.castLE (Nat.le_add_right n k))
97+
let bs ← ofFnM fun i : Fin k => f (i.natAdd n)
98+
pure (as ++ bs)) := by
99+
induction k with
100+
| zero => simp
101+
| succ k ih =>
102+
simp only [ofFnM_succ, Nat.add_eq, ih, Fin.castSucc_castLE, Fin.castSucc_natAdd, bind_pure_comp,
103+
bind_assoc, bind_map_left, Fin.natAdd_last, map_bind, Functor.map_map]
104+
congr 1
105+
funext xs
106+
congr 1
107+
funext ys
108+
congr 1
109+
funext x
110+
simp
111+
112+
@[simp] theorem toList_ofFnM [Monad m] [LawfulMonad m] {f : Fin n → m α} :
113+
toList <$> ofFnM f = List.ofFnM f := by
114+
induction n with
115+
| zero => simp
116+
| succ n ih => simp [ofFnM_succ, List.ofFnM_succ_last, ← ih]
117+
118+
@[simp]
119+
theorem ofFnM_pure_comp [Monad m] [LawfulMonad m] {n} {f : Fin n → α} :
120+
ofFnM (pure ∘ f) = (pure (ofFn f) : m (Array α)) := by
121+
apply Array.map_toList_inj.mp
122+
simp
123+
124+
-- Variant of `ofFnM_pure_comp` using a lambda.
125+
-- This is not marked a `@[simp]` as it would match on every occurrence of `ofFnM`.
126+
theorem ofFnM_pure [Monad m] [LawfulMonad m] {n} {f : Fin n → α} :
127+
ofFnM (fun i => pure (f i)) = (pure (ofFn f) : m (Array α)) :=
128+
ofFnM_pure_comp
129+
130+
@[simp, grind =] theorem idRun_ofFnM {f : Fin n → Id α} :
131+
Id.run (ofFnM f) = ofFn (fun i => Id.run (f i)) := by
132+
unfold Id.run
133+
induction n with
134+
| zero => simp
135+
| succ n ih => simp [ofFnM_succ', ofFn_succ', ih]
136+
55137
end Array

0 commit comments

Comments
 (0)