Skip to content

Commit a848cc0

Browse files
committed
fix: KMP implementation
1 parent 19533ab commit a848cc0

File tree

2 files changed

+196
-103
lines changed

2 files changed

+196
-103
lines changed

src/Init/Data/String/Pattern/String.lean

Lines changed: 160 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ public import Init.Data.String.Pattern.Basic
1010
public import Init.Data.Iterators.Internal.Termination
1111
public import Init.Data.Iterators.Consumers.Monadic.Loop
1212
import Init.Data.String.Termination
13+
public import Init.Data.Vector.Basic
1314

1415
set_option doc.verso true
1516

@@ -22,67 +23,87 @@ public section
2223

2324
namespace String.Slice.Pattern
2425

25-
inductive ForwardSliceSearcher (s : Slice) where
26-
| emptyBefore (pos : s.Pos)
27-
| emptyAt (pos : s.Pos) (h : pos ≠ s.endPos)
28-
| proper (needle : Slice) (table : Array String.Pos.Raw) (stackPos : String.Pos.Raw) (needlePos : String.Pos.Raw)
29-
| atEnd
30-
deriving Inhabited
31-
3226
namespace ForwardSliceSearcher
3327

34-
partial def buildTable (pat : Slice) : Array String.Pos.Raw :=
35-
if pat.utf8ByteSize == 0 then
36-
#[]
28+
def buildTable (pat : Slice) : Vector Nat pat.utf8ByteSize :=
29+
if h : pat.utf8ByteSize = 0 then
30+
#v[].cast h.symm
3731
else
3832
let arr := Array.emptyWithCapacity pat.utf8ByteSize
39-
let arr := arr.push 0
40-
go ⟨1⟩ arr
33+
let arr' := arr.push 0
34+
go arr' (by simp [arr']) (by simp [arr', arr]; omega) (by simp [arr', arr])
4135
where
42-
go (pos : String.Pos.Raw) (table : Array String.Pos.Raw) :=
43-
if h : pos < pat.rawEndPos then
44-
let patByte := pat.getUTF8Byte pos h
45-
let distance := computeDistance table[table.size - 1]! patByte table
46-
let distance := if patByte = pat.getUTF8Byte! distance then distance.inc else distance
47-
go pos.inc (table.push distance)
36+
go (table : Array Nat) (ht₀ : 0 < table.size) (ht : table.size ≤ pat.utf8ByteSize) (h : ∀ (i : Nat) hi, table[i]'hi ≤ i) :
37+
Vector Nat pat.utf8ByteSize :=
38+
if hs : table.size < pat.utf8ByteSize then
39+
let patByte := pat.getUTF8Byte ⟨table.size⟩ hs
40+
let dist := computeDistance patByte table ht h (table[table.size - 1])
41+
(by have := h (table.size - 1) (by omega); omega)
42+
let dist' := if pat.getUTF8Byte ⟨dist.1⟩ (by simp [Pos.Raw.lt_iff]; omega) = patByte then dist.1 + 1 else dist
43+
go (table.push dist') (by simp) (by simp; omega) (by
44+
intro i hi
45+
by_cases hi' : i = table.size
46+
· subst hi'
47+
simp [dist']
48+
have := dist.2
49+
split <;> omega
50+
· rw [Array.getElem_push_lt]
51+
· apply h
52+
· simp at hi
53+
omega)
4854
else
49-
table
50-
51-
computeDistance (distance : String.Pos.Raw) (patByte : UInt8) (table : Array String.Pos.Raw) :
52-
String.Pos.Raw :=
53-
if distance > 0 && patByte != pat.getUTF8Byte! distance then
54-
computeDistance table[distance.byteIdx - 1]! patByte table
55+
Vector.mk table (by omega)
56+
57+
computeDistance (patByte : UInt8) (table : Array Nat)
58+
(ht : table.size ≤ pat.utf8ByteSize)
59+
(h : ∀ (i : Nat) hi, table[i]'hi ≤ i) (guess : Nat) (hg : guess < table.size) :
60+
{ n : Nat // n < table.size } :=
61+
if h' : guess = 0 ∨ pat.getUTF8Byte ⟨guess⟩ (by simp [Pos.Raw.lt_iff]; omega) = patByte then
62+
⟨guess, hg⟩
5563
else
56-
distance
64+
have : table[guess - 1] < guess := by have := h (guess - 1) (by omega); omega
65+
computeDistance patByte table ht h table[guess - 1] (by omega)
66+
67+
theorem getElem_buildTable_le (pat : Slice) (i : Nat) (hi) : (buildTable pat)[i]'hi ≤ i := by
68+
rw [buildTable]
69+
split <;> rename_i h
70+
· simp [h] at hi
71+
· simp only [Array.emptyWithCapacity_eq, List.push_toArray, List.nil_append]
72+
suffices ∀ pat' table ht₀ ht h (i : Nat) hi, (buildTable.go pat' table ht₀ ht h)[i]'hi ≤ i from this ..
73+
intro pat' table ht₀ ht h i hi
74+
fun_induction buildTable.go with
75+
| case1 => assumption
76+
| case2 table ht₀ ht ht' ht'' => apply ht'
77+
78+
inductive _root_.String.Slice.Pattern.ForwardSliceSearcher (s : Slice) where
79+
| emptyBefore (pos : s.Pos)
80+
| emptyAt (pos : s.Pos) (h : pos ≠ s.endPos)
81+
| proper (needle : Slice) (table : Vector Nat needle.utf8ByteSize) (ht : table = buildTable needle)
82+
(stackPos : String.Pos.Raw) (needlePos : String.Pos.Raw) (hn : needlePos < needle.rawEndPos)
83+
| atEnd
84+
deriving Inhabited
5785

5886
@[inline]
5987
def iter (s : Slice) (pat : Slice) : Std.Iter (α := ForwardSliceSearcher s) (SearchStep s) :=
60-
if pat.utf8ByteSize == 0 then
88+
if h : pat.utf8ByteSize = 0 then
6189
{ internalState := .emptyBefore s.startPos }
6290
else
63-
{ internalState := .proper pat (buildTable pat) s.startPos.offset pat.startPos.offset }
64-
65-
partial def backtrackIfNecessary (pat : Slice) (table : Array String.Pos.Raw) (stackByte : UInt8)
66-
(needlePos : String.Pos.Raw) : String.Pos.Raw :=
67-
if needlePos != 0 && stackByte != pat.getUTF8Byte! needlePos then
68-
backtrackIfNecessary pat table stackByte table[needlePos.byteIdx - 1]!
69-
else
70-
needlePos
91+
{ internalState := .proper pat (buildTable pat) rfl s.startPos.offset pat.startPos.offset
92+
(by simp [Pos.Raw.lt_iff]; omega) }
7193

7294
instance (s : Slice) : Std.Iterators.Iterator (ForwardSliceSearcher s) Id (SearchStep s) where
7395
IsPlausibleStep it
74-
| .yield it' out =>
75-
match it.internalState with
96+
| .yield it' out | .skip it' =>
97+
match it.internalState with
7698
| .emptyBefore pos => (∃ h, it'.internalState = .emptyAt pos h) ∨ it'.internalState = .atEnd
7799
| .emptyAt pos h => ∃ newPos, pos < newPos ∧ it'.internalState = .emptyBefore newPos
78-
| .proper needle table stackPos needlePos =>
79-
(∃ newStackPos newNeedlePos,
80-
stackPos < newStackPos ∧
81-
newStackPos ≤ s.rawEndPos ∧
82-
it'.internalState = .proper needle table newStackPos newNeedlePos) ∨
100+
| .proper needle table ht stackPos needlePos hn =>
101+
(∃ newStackPos newNeedlePos hn,
102+
it'.internalState = .proper needle table ht newStackPos newNeedlePos hn ∧
103+
((s.utf8ByteSize - newStackPos.byteIdx < s.utf8ByteSize - stackPos.byteIdx) ∨
104+
(newStackPos = stackPos ∧ newNeedlePos < needlePos))) ∨
83105
it'.internalState = .atEnd
84106
| .atEnd => False
85-
| .skip _ => False
86107
| .done => True
87108
step := fun ⟨iter⟩ =>
88109
match iter with
@@ -95,67 +116,102 @@ instance (s : Slice) : Std.Iterators.Iterator (ForwardSliceSearcher s) Id (Searc
95116
| .emptyAt pos h =>
96117
let res := .rejected pos (pos.next h)
97118
pure (.deflate ⟨.yield ⟨.emptyBefore (pos.next h)⟩ res, by simp⟩)
98-
| .proper needle table stackPos needlePos =>
99-
let rec findNext (startPos : String.Pos.Raw)
100-
(currStackPos : String.Pos.Raw) (needlePos : String.Pos.Raw) (h : stackPos ≤ currStackPos) :=
101-
if h1 : currStackPos < s.rawEndPos then
102-
let stackByte := s.getUTF8Byte currStackPos h1
103-
let needlePos := backtrackIfNecessary needle table stackByte needlePos
104-
let patByte := needle.getUTF8Byte! needlePos
105-
if stackByte != patByte then
106-
let nextStackPos := s.findNextPos currStackPos h1 |>.offset
107-
let res := .rejected (s.pos! startPos) (s.pos! nextStackPos)
108-
have hiter := by
109-
left
110-
exists nextStackPos
111-
have haux := lt_offset_findNextPos h1
112-
simp only [String.Pos.Raw.lt_iff, proper.injEq, true_and, exists_and_left, exists_eq', and_true,
113-
nextStackPos]
114-
constructor
115-
· simp [String.Pos.Raw.le_iff, String.Pos.Raw.lt_iff] at h haux ⊢
116-
omega
117-
· apply Pos.Raw.IsValidForSlice.le_utf8ByteSize
118-
apply Pos.isValidForSlice
119-
.deflate ⟨.yield ⟨.proper needle table nextStackPos needlePos⟩ res, hiter⟩
119+
| .proper needle table htable stackPos needlePos hn =>
120+
-- **Invariant 1:** we have already covered everything up until `stackPos - needlePos` (exclusive),
121+
-- with matches and rejections.
122+
-- **Invariant 2:** `stackPos - needlePos` is a valid position
123+
-- **Invariant 3:** the range from `stackPos - needlePos` to `stackPos` (exclusive) is a
124+
-- prefix of the pattern.
125+
if h₁ : stackPos < s.rawEndPos then
126+
let stackByte := s.getUTF8Byte stackPos h₁
127+
let patByte := needle.getUTF8Byte needlePos hn
128+
if stackByte = patByte then
129+
let nextStackPos := stackPos.inc
130+
let nextNeedlePos := needlePos.inc
131+
if h : nextNeedlePos = needle.rawEndPos then
132+
-- Safety: the section from `nextStackPos.decreaseBy needle.utf8ByteSize` to `nextStackPos`
133+
-- (exclusive) is exactly the needle, so it must represent a valid range.
134+
let res := .matched (s.pos! (nextStackPos.decreaseBy needle.utf8ByteSize)) (s.pos! nextStackPos)
135+
-- Invariants still satisfied
136+
pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos 0
137+
(by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
138+
by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega,
139+
Or.inl (by simp [nextStackPos, Pos.Raw.lt_iff] at h₁ ⊢; omega)⟩⟩)
120140
else
121-
let needlePos := needlePos.inc
122-
if needlePos == needle.rawEndPos then
123-
let nextStackPos := currStackPos.inc
124-
let res := .matched (s.pos! startPos) (s.pos! nextStackPos)
125-
have hiter := by
126-
left
127-
exists nextStackPos
128-
simp only [Pos.Raw.byteIdx_inc, proper.injEq, true_and, exists_and_left,
129-
exists_eq', and_true, nextStackPos, String.Pos.Raw.lt_iff]
130-
constructor
131-
· simp [String.Pos.Raw.le_iff] at h ⊢
132-
omega
133-
· simp [String.Pos.Raw.le_iff, String.Pos.Raw.lt_iff] at h1 ⊢
134-
omega
135-
.deflate ⟨.yield ⟨.proper needle table nextStackPos 0⟩ res, hiter⟩
136-
else
137-
have hinv := by
138-
simp [String.Pos.Raw.le_iff] at h ⊢
139-
omega
140-
findNext startPos currStackPos.inc needlePos hinv
141+
-- Invariants still satisfied
142+
pure (.deflate ⟨.skip ⟨.proper needle table htable nextStackPos nextNeedlePos
143+
(by simp [Pos.Raw.lt_iff, nextNeedlePos, Pos.Raw.ext_iff] at h hn ⊢; omega)⟩,
144+
by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [nextNeedlePos, Pos.Raw.lt_iff, Pos.Raw.ext_iff] at h hn ⊢; omega,
145+
Or.inl (by simp [nextStackPos, Pos.Raw.lt_iff] at h₁ ⊢; omega)⟩⟩)
141146
else
142-
if startPos != s.rawEndPos then
143-
let res := .rejected (s.pos! startPos) (s.pos! currStackPos)
144-
.deflate ⟨.yield ⟨.atEnd⟩ res, by simp⟩
147+
if hnp : needlePos.byteIdx = 0 then
148+
-- Safety: by invariant 2
149+
let basePos := s.pos! stackPos
150+
-- Since we report (mis)matches by code point and not by byte, missing in the first byte
151+
-- means that we should skip ahead to the next code point.
152+
let nextStackPos := s.findNextPos stackPos h₁
153+
let res := .rejected basePos nextStackPos
154+
-- Invariants still satisfied
155+
pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos.offset 0
156+
(by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
157+
by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega,
158+
Or.inl (by
159+
have := lt_offset_findNextPos h₁
160+
have t₀ := (findNextPos _ _ h₁).isValidForSlice.le_utf8ByteSize
161+
simp [nextStackPos, Pos.Raw.lt_iff] at this ⊢; omega)⟩⟩)
145162
else
146-
.deflate ⟨.done, by simp⟩
147-
termination_by s.utf8ByteSize - currStackPos.byteIdx
148-
decreasing_by
149-
simp [String.Pos.Raw.lt_iff] at h1 ⊢
150-
omega
151-
152-
findNext stackPos stackPos needlePos (by simp)
163+
let newNeedlePos := table[needlePos.byteIdx - 1]'(by simp [Pos.Raw.lt_iff] at hn; omega)
164+
if newNeedlePos = 0 then
165+
-- Safety: by invariant 2
166+
let basePos := s.pos! (stackPos.unoffsetBy needlePos)
167+
-- Since we report (mis)matches by code point and not by byte, missing in the first byte
168+
-- means that we should skip ahead to the next code point.
169+
let nextStackPos := (s.pos? stackPos).getD (s.findNextPos stackPos h₁)
170+
let res := .rejected basePos nextStackPos
171+
-- Invariants still satisfied
172+
pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos.offset 0
173+
(by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
174+
by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega, by
175+
simp only [pos?, Pos.Raw.isValidForSlice_eq_true_iff, nextStackPos]
176+
split
177+
· exact Or.inr (by simp [Pos.Raw.lt_iff]; omega)
178+
· refine Or.inl ?_
179+
have := lt_offset_findNextPos h₁
180+
have t₀ := (findNextPos _ _ h₁).isValidForSlice.le_utf8ByteSize
181+
simp [Pos.Raw.lt_iff] at this ⊢; omega⟩⟩)
182+
else
183+
let oldBasePos := s.pos! (stackPos.decreaseBy needlePos.byteIdx)
184+
let newBasePos := s.pos! (stackPos.decreaseBy newNeedlePos)
185+
let res := .rejected oldBasePos newBasePos
186+
-- Invariants still satisfied by definition of the prefix table
187+
pure (.deflate ⟨.yield ⟨.proper needle table htable stackPos ⟨newNeedlePos⟩
188+
(by
189+
subst htable
190+
have := getElem_buildTable_le needle (needlePos.byteIdx - 1) (by simp [Pos.Raw.lt_iff] at hn; omega)
191+
simp [newNeedlePos, Pos.Raw.lt_iff] at hn ⊢
192+
omega)⟩ res,
193+
by
194+
simp only [proper.injEq, heq_eq_eq, true_and, exists_and_left, exists_prop,
195+
reduceCtorEq, or_false]
196+
refine ⟨_, _, ⟨rfl, rfl⟩, ?_, Or.inr ⟨rfl, ?_⟩⟩
197+
all_goals
198+
subst htable
199+
have := getElem_buildTable_le needle (needlePos.byteIdx - 1) (by simp [Pos.Raw.lt_iff] at hn; omega)
200+
simp [newNeedlePos, Pos.Raw.lt_iff] at hn ⊢
201+
omega⟩)
202+
else
203+
if 0 < needlePos then
204+
let basePos := stackPos.unoffsetBy needlePos
205+
let res := .rejected (s.pos! basePos) s.endPos
206+
pure (.deflate ⟨.yield ⟨.atEnd⟩ res, by simp⟩)
207+
else
208+
pure (.deflate ⟨.done, by simp⟩)
153209
| .atEnd => pure (.deflate ⟨.done, by simp⟩)
154210

155211
private def toOption : ForwardSliceSearcher s → Option (Nat × Nat)
156212
| .emptyBefore pos => some (pos.remainingBytes, 1)
157213
| .emptyAt pos _ => some (pos.remainingBytes, 0)
158-
| .proper _ _ sp _ => some (s.utf8ByteSize - sp.byteIdx, 0)
214+
| .proper _ _ _ sp np _ => some (s.utf8ByteSize - sp.byteIdx, np.byteIdx)
159215
| .atEnd => none
160216

161217
private instance : WellFoundedRelation (ForwardSliceSearcher s) where
@@ -173,7 +229,8 @@ private def finitenessRelation :
173229
simp_wf
174230
obtain ⟨step, h, h'⟩ := h
175231
cases step
176-
· cases h
232+
all_goals try
233+
cases h
177234
revert h'
178235
simp only [Std.Iterators.IterM.IsPlausibleStep, Std.Iterators.Iterator.IsPlausibleStep]
179236
match it.internalState with
@@ -184,21 +241,21 @@ private def finitenessRelation :
184241
intro x hx h
185242
simpa [h, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def,
186243
← Pos.lt_iff_remainingBytes_lt]
187-
| .proper needle table stackPos needlePos =>
188-
simp only [exists_and_left]
189-
rintro (⟨newStackPos, h₁, h₂, ⟨x, hx⟩⟩|h)
190-
· simp [hx, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, Pos.Raw.lt_iff,
191-
Pos.Raw.le_iff] at ⊢ h₁ h₂
192-
omega
244+
| .proper needle table ht stackPos needlePos hn =>
245+
rintro (⟨newStackPos, newNeedlePos, h₁, h₂, (h|⟨rfl, h⟩)⟩|h)
246+
· simp [h₂, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, h]
247+
· simpa [h₂, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, Pos.Raw.lt_iff]
193248
· simp [h, ForwardSliceSearcher.toOption, Option.lt]
194249
| .atEnd .. => simp
195-
· cases h'
196250
· cases h
197251

198252
@[no_expose]
199253
instance : Std.Iterators.Finite (ForwardSliceSearcher s) Id :=
200254
.of_finitenessRelation finitenessRelation
201255

256+
instance : Std.Iterators.IteratorCollect (ForwardSliceSearcher s) Id Id :=
257+
.defaultImplementation
258+
202259
instance : Std.Iterators.IteratorLoop (ForwardSliceSearcher s) Id Id :=
203260
.defaultImplementation
204261

tests/lean/run/string_kmp.lean

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
module
2+
3+
inductive S where
4+
| m (b e : Nat)
5+
| r (b e : Nat)
6+
deriving Repr, BEq, DecidableEq
7+
8+
def run [String.ToSlice α] [String.ToSlice β] (s : α) (pat : β) : List S :=
9+
String.Slice.Pattern.ForwardSliceSearcher.iter (String.ToSlice.toSlice s) (String.ToSlice.toSlice pat)
10+
|>.map (fun | .matched b e => S.m b.offset.byteIdx e.offset.byteIdx | .rejected b e => S.r b.offset.byteIdx e.offset.byteIdx)
11+
|>.toList
12+
13+
-- 𝔸 is [240,157,148,184]
14+
-- 𝕸 is [240,157,149,184]
15+
16+
#guard run "aababaab" "a" = [.m 0 1, .m 1 2, .r 2 3, .m 3 4, .r 4 5, .m 5 6, .m 6 7, .r 7 8]
17+
#guard run "aab" "ab" = [.r 0 1, .m 1 3]
18+
#guard run "aababacab" "ab" = [.r 0 1, .m 1 3, .m 3 5, .r 5 6, .r 6 7, .m 7 9]
19+
#guard run "aaab" "aab" = [.r 0 1, .m 1 4]
20+
#guard run "aaaaa" "aa" = [.m 0 2, .m 2 4, .r 4 5]
21+
#guard run "abcabd" "abd" = [.r 0 2, .r 2 3, .m 3 6]
22+
#guard run "αβ" "β" = [.r 0 2, .m 2 4]
23+
#guard run "𝔸" "𝕸" = [.r 0 4]
24+
#guard run "𝔸𝕸" "𝕸" = [.r 0 4, .m 4 8]
25+
#guard run "α𝔸€α𝔸₭" "α𝔸₭" = [.r 0 9, .m 9 18]
26+
#guard run "α𝔸𝕸α𝔸₭" "α𝔸₭" = [.r 0 6, .r 6 10, .m 10 19]
27+
#guard run "𝕸𝔸𝕸𝔸₭" "𝕸𝔸₭" = [.r 0 8, .m 8 19]
28+
#guard run "𝕸𝔸𝕸β₭" "𝕸𝔸₭" = [.r 0 8, .r 8 12, .r 12 14, .r 14 17]
29+
#guard run "𝔸𝔸𝔸𝔸𝕸𝔸𝔸𝔸𝕸" "𝔸𝔸𝕸" = [.r 0 4, .r 4 8, .m 8 20, .r 20 24, .m 24 36]
30+
#guard run "𝔸b" "𝕸" = [.r 0 4, .r 4 5]
31+
#guard run "𝔸bb𝕸β" "𝕸" = [.r 0 4, .r 4 5, .r 5 6, .m 6 10, .r 10 12]
32+
#guard run "𝔸bbββαβαββββ𝕸β" "ββ𝕸" = [.r 0 4, .r 4 5, .r 5 6, .r 6 8, .r 8 10, .r 10 12, .r 12 14, .r 14 16, .r 16 18, .r 18 20, .m 20 28, .r 28 30]
33+
#guard run "𝔸β𝕸" "𝕸" = [.r 0 4, .r 4 6, .m 6 10]
34+
#guard run "𝔸b𝕸xu∅" "𝕸x" = [.r 0 4, .r 4 5, .m 5 10, .r 10 11, .r 11 14]
35+
#guard run "é" "ù" = [.r 0 2]
36+
#guard run "éB" "ù" = [.r 0 2, .r 2 3]

0 commit comments

Comments
 (0)