|
| 1 | +/- |
| 2 | +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. |
| 3 | +Released under Apache 2.0 license as described in the file LICENSE. |
| 4 | +Authors: Paul Reichert |
| 5 | +-/ |
| 6 | +module |
| 7 | + |
| 8 | +prelude |
| 9 | +public import Init.Data.Nat.Lemmas |
| 10 | +public import Init.Data.Iterators.Consumers.Monadic.Collect |
| 11 | +public import Init.Data.Iterators.Consumers.Monadic.Loop |
| 12 | +public import Init.Data.Iterators.Internal.Termination |
| 13 | + |
| 14 | +@[expose] public section |
| 15 | + |
| 16 | +/-! |
| 17 | +This module provides the iterator combinator `IterM.take`. |
| 18 | +-/ |
| 19 | + |
| 20 | +namespace Std.Iterators |
| 21 | + |
| 22 | +variable {α : Type w} {m : Type w → Type w'} {β : Type w} |
| 23 | + |
| 24 | +/-- |
| 25 | +The internal state of the `IterM.take` iterator combinator. |
| 26 | +-/ |
| 27 | +@[unbox] |
| 28 | +structure Take (α : Type w) (m : Type w → Type w') {β : Type w} [Iterator α m β] where |
| 29 | + /-- |
| 30 | + Internal implementation detail of the iterator library. |
| 31 | + Caution: For `take n`, `countdown` is `n + 1`. |
| 32 | + If `countdown` is zero, the combinator only terminates when `inner` terminates. |
| 33 | + -/ |
| 34 | + countdown : Nat |
| 35 | + /-- Internal implementation detail of the iterator library -/ |
| 36 | + inner : IterM (α := α) m β |
| 37 | + /-- |
| 38 | + Internal implementation detail of the iterator library. |
| 39 | + This proof term ensures that a `take` always produces a finite iterator from a productive one. |
| 40 | + -/ |
| 41 | + finite : countdown > 0 ∨ Finite α m |
| 42 | + |
| 43 | +/-- |
| 44 | +Given an iterator `it` and a natural number `n`, `it.take n` is an iterator that outputs |
| 45 | +up to the first `n` of `it`'s values in order and then terminates. |
| 46 | +
|
| 47 | +**Marble diagram:** |
| 48 | +
|
| 49 | +```text |
| 50 | +it ---a----b---c--d-e--⊥ |
| 51 | +it.take 3 ---a----b---c⊥ |
| 52 | +
|
| 53 | +it ---a--⊥ |
| 54 | +it.take 3 ---a--⊥ |
| 55 | +``` |
| 56 | +
|
| 57 | +**Termination properties:** |
| 58 | +
|
| 59 | +* `Finite` instance: only if `it` is productive |
| 60 | +* `Productive` instance: only if `it` is productive |
| 61 | +
|
| 62 | +**Performance:** |
| 63 | +
|
| 64 | +This combinator incurs an additional O(1) cost with each output of `it`. |
| 65 | +-/ |
| 66 | +@[always_inline, inline] |
| 67 | +def IterM.take [Iterator α m β] (n : Nat) (it : IterM (α := α) m β) := |
| 68 | + toIterM (Take.mk (n + 1) it (Or.inl <| Nat.zero_lt_succ _)) m β |
| 69 | + |
| 70 | +/-- |
| 71 | +This combinator is only useful for advanced use cases. |
| 72 | +
|
| 73 | +Given a finite iterator `it`, returns an iterator that behaves exactly like `it` but is of the same |
| 74 | +type as `it.take n`. |
| 75 | +
|
| 76 | +**Marble diagram:** |
| 77 | +
|
| 78 | +```text |
| 79 | +it ---a----b---c--d-e--⊥ |
| 80 | +it.toTake ---a----b---c--d-e--⊥ |
| 81 | +``` |
| 82 | +
|
| 83 | +**Termination properties:** |
| 84 | +
|
| 85 | +* `Finite` instance: always |
| 86 | +* `Productive` instance: always |
| 87 | +
|
| 88 | +**Performance:** |
| 89 | +
|
| 90 | +This combinator incurs an additional O(1) cost with each output of `it`. |
| 91 | +-/ |
| 92 | +@[always_inline, inline] |
| 93 | +def IterM.toTake [Iterator α m β] [Finite α m] (it : IterM (α := α) m β) := |
| 94 | + toIterM (Take.mk 0 it (Or.inr inferInstance)) m β |
| 95 | + |
| 96 | +theorem IterM.take.surjective_of_zero_lt {α : Type w} {m : Type w → Type w'} {β : Type w} |
| 97 | + [Iterator α m β] (it : IterM (α := Take α m) m β) (h : 0 < it.internalState.countdown) : |
| 98 | + ∃ (it₀ : IterM (α := α) m β) (k : Nat), it = it₀.take k := by |
| 99 | + refine ⟨it.internalState.inner, it.internalState.countdown - 1, ?_⟩ |
| 100 | + simp only [take, Nat.sub_add_cancel (m := 1) (n := it.internalState.countdown) (by omega)] |
| 101 | + rfl |
| 102 | + |
| 103 | +inductive Take.PlausibleStep [Iterator α m β] (it : IterM (α := Take α m) m β) : |
| 104 | + (step : IterStep (IterM (α := Take α m) m β) β) → Prop where |
| 105 | + | yield : ∀ {it' out}, it.internalState.inner.IsPlausibleStep (.yield it' out) → |
| 106 | + (h : it.internalState.countdown ≠ 1) → PlausibleStep it (.yield ⟨it.internalState.countdown - 1, it', it.internalState.finite.imp_left (by omega)⟩ out) |
| 107 | + | skip : ∀ {it'}, it.internalState.inner.IsPlausibleStep (.skip it') → |
| 108 | + it.internalState.countdown ≠ 1 → PlausibleStep it (.skip ⟨it.internalState.countdown, it', it.internalState.finite⟩) |
| 109 | + | done : it.internalState.inner.IsPlausibleStep .done → PlausibleStep it .done |
| 110 | + | depleted : it.internalState.countdown = 1 → |
| 111 | + PlausibleStep it .done |
| 112 | + |
| 113 | +@[always_inline, inline] |
| 114 | +instance Take.instIterator [Monad m] [Iterator α m β] : Iterator (Take α m) m β where |
| 115 | + IsPlausibleStep := Take.PlausibleStep |
| 116 | + step it := |
| 117 | + if h : it.internalState.countdown = 1 then |
| 118 | + pure <| .deflate <| .done (.depleted h) |
| 119 | + else do |
| 120 | + match (← it.internalState.inner.step).inflate with |
| 121 | + | .yield it' out h' => |
| 122 | + pure <| .deflate <| .yield ⟨it.internalState.countdown - 1, it', (it.internalState.finite.imp_left (by omega))⟩ out (.yield h' h) |
| 123 | + | .skip it' h' => pure <| .deflate <| .skip ⟨it.internalState.countdown, it', it.internalState.finite⟩ (.skip h' h) |
| 124 | + | .done h' => pure <| .deflate <| .done (.done h') |
| 125 | + |
| 126 | +def Take.Rel (m : Type w → Type w') [Monad m] [Iterator α m β] [Productive α m] : |
| 127 | + IterM (α := Take α m) m β → IterM (α := Take α m) m β → Prop := |
| 128 | + open scoped Classical in |
| 129 | + if _ : Finite α m then |
| 130 | + InvImage (Prod.Lex Nat.lt_wfRel.rel IterM.TerminationMeasures.Finite.Rel) |
| 131 | + (fun it => (it.internalState.countdown, it.internalState.inner.finitelyManySteps)) |
| 132 | + else |
| 133 | + InvImage (Prod.Lex Nat.lt_wfRel.rel IterM.TerminationMeasures.Productive.Rel) |
| 134 | + (fun it => (it.internalState.countdown, it.internalState.inner.finitelyManySkips)) |
| 135 | + |
| 136 | +theorem Take.rel_of_countdown [Monad m] [Iterator α m β] [Productive α m] |
| 137 | + {it it' : IterM (α := Take α m) m β} |
| 138 | + (h : it'.internalState.countdown < it.internalState.countdown) : Take.Rel m it' it := by |
| 139 | + simp only [Rel] |
| 140 | + split <;> exact Prod.Lex.left _ _ h |
| 141 | + |
| 142 | +theorem Take.rel_of_inner [Monad m] [Iterator α m β] [Productive α m] {remaining : Nat} |
| 143 | + {it it' : IterM (α := α) m β} |
| 144 | + (h : it'.finitelyManySkips.Rel it.finitelyManySkips) : |
| 145 | + Take.Rel m (it'.take remaining) (it.take remaining) := by |
| 146 | + simp only [Rel] |
| 147 | + split |
| 148 | + · exact Prod.Lex.right _ (.of_productive h) |
| 149 | + · exact Prod.Lex.right _ h |
| 150 | + |
| 151 | +theorem Take.rel_of_zero_of_inner [Monad m] [Iterator α m β] |
| 152 | + {it it' : IterM (α := Take α m) m β} |
| 153 | + (h : it.internalState.countdown = 0) (h' : it'.internalState.countdown = 0) |
| 154 | + (h'' : haveI := it.internalState.finite.resolve_left (by omega); it'.internalState.inner.finitelyManySteps.Rel it.internalState.inner.finitelyManySteps) : |
| 155 | + haveI := it.internalState.finite.resolve_left (by omega) |
| 156 | + Take.Rel m it' it := by |
| 157 | + haveI := it.internalState.finite.resolve_left (by omega) |
| 158 | + simp only [Rel, this, ↓reduceDIte, InvImage, h, h'] |
| 159 | + exact Prod.Lex.right _ h'' |
| 160 | + |
| 161 | +private def Take.instFinitenessRelation [Monad m] [Iterator α m β] |
| 162 | + [Productive α m] : |
| 163 | + FinitenessRelation (Take α m) m where |
| 164 | + rel := Take.Rel m |
| 165 | + wf := by |
| 166 | + rw [Rel] |
| 167 | + split |
| 168 | + all_goals |
| 169 | + apply InvImage.wf |
| 170 | + refine ⟨fun (a, b) => Prod.lexAccessible (WellFounded.apply ?_ a) (WellFounded.apply ?_) b⟩ |
| 171 | + · exact WellFoundedRelation.wf |
| 172 | + · exact WellFoundedRelation.wf |
| 173 | + subrelation {it it'} h := by |
| 174 | + obtain ⟨step, h, h'⟩ := h |
| 175 | + cases h' |
| 176 | + case yield it' out k h' h'' => |
| 177 | + cases h |
| 178 | + cases it.internalState.finite |
| 179 | + · apply rel_of_countdown |
| 180 | + simp only |
| 181 | + omega |
| 182 | + · by_cases h : it.internalState.countdown = 0 |
| 183 | + · simp only [h, Nat.zero_le, Nat.sub_eq_zero_of_le] |
| 184 | + apply rel_of_zero_of_inner h rfl |
| 185 | + exact .single ⟨_, rfl, h'⟩ |
| 186 | + · apply rel_of_countdown |
| 187 | + simp only |
| 188 | + omega |
| 189 | + case skip it' out k h' h'' => |
| 190 | + cases h |
| 191 | + by_cases h : it.internalState.countdown = 0 |
| 192 | + · simp only [h] |
| 193 | + apply Take.rel_of_zero_of_inner h rfl |
| 194 | + exact .single ⟨_, rfl, h'⟩ |
| 195 | + · obtain ⟨it, k, rfl⟩ := IterM.take.surjective_of_zero_lt it (by omega) |
| 196 | + apply Take.rel_of_inner |
| 197 | + exact IterM.TerminationMeasures.Productive.rel_of_skip h' |
| 198 | + case done _ => |
| 199 | + cases h |
| 200 | + case depleted _ => |
| 201 | + cases h |
| 202 | + |
| 203 | +instance Take.instFinite [Monad m] [Iterator α m β] [Productive α m] : |
| 204 | + Finite (Take α m) m := |
| 205 | + by exact Finite.of_finitenessRelation instFinitenessRelation |
| 206 | + |
| 207 | +instance Take.instIteratorCollect {n : Type w → Type w'} [Monad m] [Monad n] [Iterator α m β] : |
| 208 | + IteratorCollect (Take α m) m n := |
| 209 | + .defaultImplementation |
| 210 | + |
| 211 | +instance Take.instIteratorCollectPartial {n : Type w → Type w'} [Monad m] [Monad n] [Iterator α m β] : |
| 212 | + IteratorCollectPartial (Take α m) m n := |
| 213 | + .defaultImplementation |
| 214 | + |
| 215 | +instance Take.instIteratorLoop {n : Type x → Type x'} [Monad m] [Monad n] [Iterator α m β] : |
| 216 | + IteratorLoop (Take α m) m n := |
| 217 | + .defaultImplementation |
| 218 | + |
| 219 | +instance Take.instIteratorLoopPartial [Monad m] [Monad n] [Iterator α m β] : |
| 220 | + IteratorLoopPartial (Take α m) m n := |
| 221 | + .defaultImplementation |
| 222 | + |
| 223 | +end Std.Iterators |
0 commit comments