@@ -10,6 +10,7 @@ public import Init.Data.String.Pattern.Basic
1010public import Init.Data.Iterators.Internal.Termination
1111public import Init.Data.Iterators.Consumers.Monadic.Loop
1212import Init.Data.String.Termination
13+ public import Init.Data.Vector.Basic
1314
1415set_option doc.verso true
1516
@@ -22,67 +23,87 @@ public section
2223
2324namespace String.Slice.Pattern
2425
25- inductive ForwardSliceSearcher (s : Slice) where
26- | emptyBefore (pos : s.Pos)
27- | emptyAt (pos : s.Pos) (h : pos ≠ s.endPos)
28- | proper (needle : Slice) (table : Array String.Pos.Raw) (stackPos : String.Pos.Raw) (needlePos : String.Pos.Raw)
29- | atEnd
30- deriving Inhabited
31-
3226namespace ForwardSliceSearcher
3327
34- partial def buildTable (pat : Slice) : Array String.Pos.Raw :=
35- if pat.utf8ByteSize = = 0 then
36- #[]
28+ def buildTable (pat : Slice) : Vector Nat pat.utf8ByteSize :=
29+ if h : pat.utf8ByteSize = 0 then
30+ #v[].cast h.symm
3731 else
3832 let arr := Array.emptyWithCapacity pat.utf8ByteSize
39- let arr := arr.push 0
40- go ⟨ 1 ⟩ arr
33+ let arr' := arr.push 0
34+ go arr' ( by simp [ arr']) ( by simp [arr', arr]; omega) ( by simp [arr', arr])
4135where
42- go (pos : String.Pos.Raw) (table : Array String.Pos.Raw) :=
43- if h : pos < pat.rawEndPos then
44- let patByte := pat.getUTF8Byte pos h
45- let distance := computeDistance table[table.size - 1 ]! patByte table
46- let distance := if patByte = pat.getUTF8Byte! distance then distance.inc else distance
47- go pos.inc (table.push distance)
36+ go (table : Array Nat) (ht₀ : 0 < table.size) (ht : table.size ≤ pat.utf8ByteSize) (h : ∀ (i : Nat) hi, table[i]'hi ≤ i) :
37+ Vector Nat pat.utf8ByteSize :=
38+ if hs : table.size < pat.utf8ByteSize then
39+ let patByte := pat.getUTF8Byte ⟨table.size⟩ hs
40+ let dist := computeDistance patByte table ht h (table[table.size - 1 ])
41+ (by have := h (table.size - 1 ) (by omega); omega)
42+ let dist' := if pat.getUTF8Byte ⟨dist.1 ⟩ (by simp [Pos.Raw.lt_iff]; omega) = patByte then dist.1 + 1 else dist
43+ go (table.push dist') (by simp) (by simp; omega) (by
44+ intro i hi
45+ by_cases hi' : i = table.size
46+ · subst hi'
47+ simp [dist']
48+ have := dist.2
49+ split <;> omega
50+ · rw [Array.getElem_push_lt]
51+ · apply h
52+ · simp at hi
53+ omega)
4854 else
49- table
50-
51- computeDistance (distance : String.Pos.Raw) (patByte : UInt8) (table : Array String.Pos.Raw) :
52- String.Pos.Raw :=
53- if distance > 0 && patByte != pat.getUTF8Byte! distance then
54- computeDistance table[distance.byteIdx - 1 ]! patByte table
55+ Vector.mk table (by omega)
56+
57+ computeDistance (patByte : UInt8) (table : Array Nat)
58+ (ht : table.size ≤ pat.utf8ByteSize)
59+ (h : ∀ (i : Nat) hi, table[i]'hi ≤ i) (guess : Nat) (hg : guess < table.size) :
60+ { n : Nat // n < table.size } :=
61+ if h' : guess = 0 ∨ pat.getUTF8Byte ⟨guess⟩ (by simp [Pos.Raw.lt_iff]; omega) = patByte then
62+ ⟨guess, hg⟩
5563 else
56- distance
64+ have : table[guess - 1 ] < guess := by have := h (guess - 1 ) (by omega); omega
65+ computeDistance patByte table ht h table[guess - 1 ] (by omega)
66+
67+ theorem getElem_buildTable_le (pat : Slice) (i : Nat) (hi) : (buildTable pat)[i]'hi ≤ i := by
68+ rw [buildTable]
69+ split <;> rename_i h
70+ · simp [h] at hi
71+ · simp only [Array.emptyWithCapacity_eq, List.push_toArray, List.nil_append]
72+ suffices ∀ pat' table ht₀ ht h (i : Nat) hi, (buildTable.go pat' table ht₀ ht h)[i]'hi ≤ i from this ..
73+ intro pat' table ht₀ ht h i hi
74+ fun_induction buildTable.go with
75+ | case1 => assumption
76+ | case2 table ht₀ ht ht' ht'' => apply ht'
77+
78+ inductive _root_.String.Slice.Pattern.ForwardSliceSearcher (s : Slice) where
79+ | emptyBefore (pos : s.Pos)
80+ | emptyAt (pos : s.Pos) (h : pos ≠ s.endPos)
81+ | proper (needle : Slice) (table : Vector Nat needle.utf8ByteSize) (ht : table = buildTable needle)
82+ (stackPos : String.Pos.Raw) (needlePos : String.Pos.Raw) (hn : needlePos < needle.rawEndPos)
83+ | atEnd
84+ deriving Inhabited
5785
5886@[inline]
5987def iter (s : Slice) (pat : Slice) : Std.Iter (α := ForwardSliceSearcher s) (SearchStep s) :=
60- if pat.utf8ByteSize = = 0 then
88+ if h : pat.utf8ByteSize = 0 then
6189 { internalState := .emptyBefore s.startPos }
6290 else
63- { internalState := .proper pat (buildTable pat) s.startPos.offset pat.startPos.offset }
64-
65- partial def backtrackIfNecessary (pat : Slice) (table : Array String.Pos.Raw) (stackByte : UInt8)
66- (needlePos : String.Pos.Raw) : String.Pos.Raw :=
67- if needlePos != 0 && stackByte != pat.getUTF8Byte! needlePos then
68- backtrackIfNecessary pat table stackByte table[needlePos.byteIdx - 1 ]!
69- else
70- needlePos
91+ { internalState := .proper pat (buildTable pat) rfl s.startPos.offset pat.startPos.offset
92+ (by simp [Pos.Raw.lt_iff]; omega) }
7193
7294instance (s : Slice) : Std.Iterators.Iterator (ForwardSliceSearcher s) Id (SearchStep s) where
7395 IsPlausibleStep it
74- | .yield it' out =>
75- match it.internalState with
96+ | .yield it' out | .skip it' =>
97+ match it.internalState with
7698 | .emptyBefore pos => (∃ h, it'.internalState = .emptyAt pos h) ∨ it'.internalState = .atEnd
7799 | .emptyAt pos h => ∃ newPos, pos < newPos ∧ it'.internalState = .emptyBefore newPos
78- | .proper needle table stackPos needlePos =>
79- (∃ newStackPos newNeedlePos,
80- stackPos < newStackPos ∧
81- newStackPos ≤ s.rawEndPos ∧
82- it'.internalState = .proper needle table newStackPos newNeedlePos) ∨
100+ | .proper needle table ht stackPos needlePos hn =>
101+ (∃ newStackPos newNeedlePos hn ,
102+ it'.internalState = .proper needle table ht newStackPos newNeedlePos hn ∧
103+ ((s.utf8ByteSize - newStackPos.byteIdx < s.utf8ByteSize - stackPos.byteIdx) ∨
104+ ( newStackPos = stackPos ∧ newNeedlePos < needlePos)) ) ∨
83105 it'.internalState = .atEnd
84106 | .atEnd => False
85- | .skip _ => False
86107 | .done => True
87108 step := fun ⟨iter⟩ =>
88109 match iter with
@@ -95,67 +116,102 @@ instance (s : Slice) : Std.Iterators.Iterator (ForwardSliceSearcher s) Id (Searc
95116 | .emptyAt pos h =>
96117 let res := .rejected pos (pos.next h)
97118 pure (.deflate ⟨.yield ⟨.emptyBefore (pos.next h)⟩ res, by simp⟩)
98- | .proper needle table stackPos needlePos =>
99- let rec findNext (startPos : String.Pos.Raw)
100- (currStackPos : String.Pos.Raw) (needlePos : String.Pos.Raw) (h : stackPos ≤ currStackPos) :=
101- if h1 : currStackPos < s.rawEndPos then
102- let stackByte := s.getUTF8Byte currStackPos h1
103- let needlePos := backtrackIfNecessary needle table stackByte needlePos
104- let patByte := needle.getUTF8Byte! needlePos
105- if stackByte != patByte then
106- let nextStackPos := s.findNextPos currStackPos h1 |>.offset
107- let res := .rejected (s.pos! startPos) (s.pos! nextStackPos)
108- have hiter := by
109- left
110- exists nextStackPos
111- have haux := lt_offset_findNextPos h1
112- simp only [String.Pos.Raw.lt_iff, proper.injEq, true_and, exists_and_left, exists_eq', and_true,
113- nextStackPos]
114- constructor
115- · simp [String.Pos.Raw.le_iff, String.Pos.Raw.lt_iff] at h haux ⊢
116- omega
117- · apply Pos.Raw.IsValidForSlice.le_utf8ByteSize
118- apply Pos.isValidForSlice
119- .deflate ⟨.yield ⟨.proper needle table nextStackPos needlePos⟩ res, hiter⟩
119+ | .proper needle table htable stackPos needlePos hn =>
120+ -- **Invariant 1:** we have already covered everything up until `stackPos - needlePos` (exclusive),
121+ -- with matches and rejections.
122+ -- **Invariant 2:** `stackPos - needlePos` is a valid position
123+ -- **Invariant 3:** the range from from `stackPos - needlePos` to `stackPos` (exclusive) is a
124+ -- prefix of the pattern.
125+ if h₁ : stackPos < s.rawEndPos then
126+ let stackByte := s.getUTF8Byte stackPos h₁
127+ let patByte := needle.getUTF8Byte needlePos hn
128+ if stackByte = patByte then
129+ let nextStackPos := stackPos.inc
130+ let nextNeedlePos := needlePos.inc
131+ if h : nextNeedlePos = needle.rawEndPos then
132+ -- Safety: the section from `nextStackPos.descreaseBy needle.utf8ByteSize` to `nextStackPos`
133+ -- (exclusive) is exactly the needle, so it must represent a valid range.
134+ let res := .matched (s.pos! (nextStackPos.decreaseBy needle.utf8ByteSize)) (s.pos! nextStackPos)
135+ -- Invariants still satisfied
136+ pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos 0
137+ (by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
138+ by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega,
139+ Or.inl (by simp [nextStackPos, Pos.Raw.lt_iff] at h₁ ⊢; omega)⟩⟩)
120140 else
121- let needlePos := needlePos.inc
122- if needlePos == needle.rawEndPos then
123- let nextStackPos := currStackPos.inc
124- let res := .matched (s.pos! startPos) (s.pos! nextStackPos)
125- have hiter := by
126- left
127- exists nextStackPos
128- simp only [Pos.Raw.byteIdx_inc, proper.injEq, true_and, exists_and_left,
129- exists_eq', and_true, nextStackPos, String.Pos.Raw.lt_iff]
130- constructor
131- · simp [String.Pos.Raw.le_iff] at h ⊢
132- omega
133- · simp [String.Pos.Raw.le_iff, String.Pos.Raw.lt_iff] at h1 ⊢
134- omega
135- .deflate ⟨.yield ⟨.proper needle table nextStackPos 0 ⟩ res, hiter⟩
136- else
137- have hinv := by
138- simp [String.Pos.Raw.le_iff] at h ⊢
139- omega
140- findNext startPos currStackPos.inc needlePos hinv
141+ -- Invariants still satisfied
142+ pure (.deflate ⟨.skip ⟨.proper needle table htable nextStackPos nextNeedlePos
143+ (by simp [Pos.Raw.lt_iff, nextNeedlePos, Pos.Raw.ext_iff] at h hn ⊢; omega)⟩,
144+ by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [nextNeedlePos, Pos.Raw.lt_iff, Pos.Raw.ext_iff] at h hn ⊢; omega,
145+ Or.inl (by simp [nextStackPos, Pos.Raw.lt_iff] at h₁ ⊢; omega)⟩⟩)
141146 else
142- if startPos != s.rawEndPos then
143- let res := .rejected (s.pos! startPos) (s.pos! currStackPos)
144- .deflate ⟨.yield ⟨.atEnd⟩ res, by simp⟩
147+ if hnp : needlePos.byteIdx = 0 then
148+ -- Safety: by invariant 2
149+ let basePos := s.pos! stackPos
150+ -- Since we report (mis)matches by code point and not by byte, missing in the first byte
151+ -- means that we should skip ahead to the next code point.
152+ let nextStackPos := s.findNextPos stackPos h₁
153+ let res := .rejected basePos nextStackPos
154+ -- Invariants still satisfied
155+ pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos.offset 0
156+ (by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
157+ by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega,
158+ Or.inl (by
159+ have := lt_offset_findNextPos h₁
160+ have t₀ := (findNextPos _ _ h₁).isValidForSlice.le_utf8ByteSize
161+ simp [nextStackPos, Pos.Raw.lt_iff] at this ⊢; omega)⟩⟩)
145162 else
146- .deflate ⟨.done, by simp⟩
147- termination_by s.utf8ByteSize - currStackPos.byteIdx
148- decreasing_by
149- simp [String.Pos.Raw.lt_iff] at h1 ⊢
150- omega
151-
152- findNext stackPos stackPos needlePos (by simp)
163+ let newNeedlePos := table[needlePos.byteIdx - 1 ]'(by simp [Pos.Raw.lt_iff] at hn; omega)
164+ if newNeedlePos = 0 then
165+ -- Safety: by invariant 2
166+ let basePos := s.pos! (stackPos.unoffsetBy needlePos)
167+ -- Since we report (mis)matches by code point and not by byte, missing in the first byte
168+ -- means that we should skip ahead to the next code point.
169+ let nextStackPos := (s.pos? stackPos).getD (s.findNextPos stackPos h₁)
170+ let res := .rejected basePos nextStackPos
171+ -- Invariants still satisfied
172+ pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos.offset 0
173+ (by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
174+ by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega, by
175+ simp only [pos?, Pos.Raw.isValidForSlice_eq_true_iff, nextStackPos]
176+ split
177+ · exact Or.inr (by simp [Pos.Raw.lt_iff]; omega)
178+ · refine Or.inl ?_
179+ have := lt_offset_findNextPos h₁
180+ have t₀ := (findNextPos _ _ h₁).isValidForSlice.le_utf8ByteSize
181+ simp [Pos.Raw.lt_iff] at this ⊢; omega⟩⟩)
182+ else
183+ let oldBasePos := s.pos! (stackPos.decreaseBy needlePos.byteIdx)
184+ let newBasePos := s.pos! (stackPos.decreaseBy newNeedlePos)
185+ let res := .rejected oldBasePos newBasePos
186+ -- Invariants still satisfied by definition of the prefix table
187+ pure (.deflate ⟨.yield ⟨.proper needle table htable stackPos ⟨newNeedlePos⟩
188+ (by
189+ subst htable
190+ have := getElem_buildTable_le needle (needlePos.byteIdx - 1 ) (by simp [Pos.Raw.lt_iff] at hn; omega)
191+ simp [newNeedlePos, Pos.Raw.lt_iff] at hn ⊢
192+ omega)⟩ res,
193+ by
194+ simp only [proper.injEq, heq_eq_eq, true_and, exists_and_left, exists_prop,
195+ reduceCtorEq, or_false]
196+ refine ⟨_, _, ⟨rfl, rfl⟩, ?_, Or.inr ⟨rfl, ?_⟩⟩
197+ all_goals
198+ subst htable
199+ have := getElem_buildTable_le needle (needlePos.byteIdx - 1 ) (by simp [Pos.Raw.lt_iff] at hn; omega)
200+ simp [newNeedlePos, Pos.Raw.lt_iff] at hn ⊢
201+ omega⟩)
202+ else
203+ if 0 < needlePos then
204+ let basePos := stackPos.unoffsetBy needlePos
205+ let res := .rejected (s.pos! basePos) s.endPos
206+ pure (.deflate ⟨.yield ⟨.atEnd⟩ res, by simp⟩)
207+ else
208+ pure (.deflate ⟨.done, by simp⟩)
153209 | .atEnd => pure (.deflate ⟨.done, by simp⟩)
154210
155211private def toOption : ForwardSliceSearcher s → Option (Nat × Nat)
156212 | .emptyBefore pos => some (pos.remainingBytes, 1 )
157213 | .emptyAt pos _ => some (pos.remainingBytes, 0 )
158- | .proper _ _ sp _ => some (s.utf8ByteSize - sp.byteIdx, 0 )
214+ | .proper _ _ _ sp np _ => some (s.utf8ByteSize - sp.byteIdx, np.byteIdx )
159215 | .atEnd => none
160216
161217private instance : WellFoundedRelation (ForwardSliceSearcher s) where
@@ -173,7 +229,8 @@ private def finitenessRelation :
173229 simp_wf
174230 obtain ⟨step, h, h'⟩ := h
175231 cases step
176- · cases h
232+ all_goals try
233+ cases h
177234 revert h'
178235 simp only [Std.Iterators.IterM.IsPlausibleStep, Std.Iterators.Iterator.IsPlausibleStep]
179236 match it.internalState with
@@ -184,21 +241,21 @@ private def finitenessRelation :
184241 intro x hx h
185242 simpa [h, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def,
186243 ← Pos.lt_iff_remainingBytes_lt]
187- | .proper needle table stackPos needlePos =>
188- simp only [exists_and_left]
189- rintro (⟨newStackPos, h₁, h₂, ⟨x, hx⟩⟩|h)
190- · simp [hx, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, Pos.Raw.lt_iff,
191- Pos.Raw.le_iff] at ⊢ h₁ h₂
192- omega
244+ | .proper needle table ht stackPos needlePos hn =>
245+ rintro (⟨newStackPos, newNeedlePos, h₁, h₂, (h|⟨rfl, h⟩)⟩|h)
246+ · simp [h₂, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, h]
247+ · simpa [h₂, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, Pos.Raw.lt_iff]
193248 · simp [h, ForwardSliceSearcher.toOption, Option.lt]
194249 | .atEnd .. => simp
195- · cases h'
196250 · cases h
197251
198252@[no_expose]
199253instance : Std.Iterators.Finite (ForwardSliceSearcher s) Id :=
200254 .of_finitenessRelation finitenessRelation
201255
256+ instance : Std.Iterators.IteratorCollect (ForwardSliceSearcher s) Id Id :=
257+ .defaultImplementation
258+
202259instance : Std.Iterators.IteratorLoop (ForwardSliceSearcher s) Id Id :=
203260 .defaultImplementation
204261
0 commit comments