Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Init/Data/String/Pattern/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ inductive SearchStep (s : Slice) where
The subslice starting at {name}`startPos` and ending at {name}`endPos` matches the pattern.
-/
| matched (startPos endPos : s.Pos)
deriving Inhabited
deriving Inhabited, BEq

/--
Provides a conversion from a pattern to an iterator of {name}`SearchStep` that searches for matches
Expand Down
318 changes: 188 additions & 130 deletions src/Init/Data/String/Pattern/String.lean
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ prelude
public import Init.Data.String.Pattern.Basic
public import Init.Data.Iterators.Internal.Termination
public import Init.Data.Iterators.Consumers.Monadic.Loop
public import Init.Data.Vector.Basic

set_option doc.verso true

Expand All @@ -21,142 +22,202 @@ public section

namespace String.Slice.Pattern

inductive ForwardSliceSearcher (s : Slice) where
| empty (pos : s.Pos)
| proper (needle : Slice) (table : Array String.Pos.Raw) (stackPos : String.Pos.Raw) (needlePos : String.Pos.Raw)
| atEnd
deriving Inhabited

namespace ForwardSliceSearcher

partial def buildTable (pat : Slice) : Array String.Pos.Raw :=
if pat.utf8ByteSize == 0 then
#[]
def buildTable (pat : Slice) : Vector Nat pat.utf8ByteSize :=
if h : pat.utf8ByteSize = 0 then
#v[].cast h.symm
else
let arr := Array.emptyWithCapacity pat.utf8ByteSize
let arr := arr.push 0
go ⟨1⟩ arr
let arr' := arr.push 0
go arr' (by simp [arr']) (by simp [arr', arr]; omega) (by simp [arr', arr])
where
go (pos : String.Pos.Raw) (table : Array String.Pos.Raw) :=
if h : pos < pat.rawEndPos then
let patByte := pat.getUTF8Byte pos h
let distance := computeDistance table[table.size - 1]! patByte table
let distance := if patByte = pat.getUTF8Byte! distance then distance.inc else distance
go pos.inc (table.push distance)
go (table : Array Nat) (ht₀ : 0 < table.size) (ht : table.size ≤ pat.utf8ByteSize) (h : ∀ (i : Nat) hi, table[i]'hi ≤ i) :
Vector Nat pat.utf8ByteSize :=
if hs : table.size < pat.utf8ByteSize then
let patByte := pat.getUTF8Byte ⟨table.size⟩ hs
let dist := computeDistance patByte table ht h (table[table.size - 1])
(by have := h (table.size - 1) (by omega); omega)
let dist' := if pat.getUTF8Byte ⟨dist.1⟩ (by simp [Pos.Raw.lt_iff]; omega) = patByte then dist.1 + 1 else dist
go (table.push dist') (by simp) (by simp; omega) (by
intro i hi
by_cases hi' : i = table.size
· subst hi'
simp [dist']
have := dist.2
split <;> omega
· rw [Array.getElem_push_lt]
· apply h
· simp at hi
omega)
else
table

computeDistance (distance : String.Pos.Raw) (patByte : UInt8) (table : Array String.Pos.Raw) :
String.Pos.Raw :=
if distance > 0 && patByte != pat.getUTF8Byte! distance then
computeDistance table[distance.byteIdx - 1]! patByte table
Vector.mk table (by omega)

computeDistance (patByte : UInt8) (table : Array Nat)
(ht : table.size ≤ pat.utf8ByteSize)
(h : ∀ (i : Nat) hi, table[i]'hi ≤ i) (guess : Nat) (hg : guess < table.size) :
{ n : Nat // n < table.size } :=
if h' : guess = 0 ∨ pat.getUTF8Byte ⟨guess⟩ (by simp [Pos.Raw.lt_iff]; omega) = patByte then
⟨guess, hg⟩
else
distance
have : table[guess - 1] < guess := by have := h (guess - 1) (by omega); omega
computeDistance patByte table ht h table[guess - 1] (by omega)

-- Every entry of the table produced by buildTable is bounded by its own index:
-- (buildTable pat)[i] ≤ i for every in-bounds index i.
theorem getElem_buildTable_le (pat : Slice) (i : Nat) (hi) : (buildTable pat)[i]'hi ≤ i := by
rw [buildTable]
-- Case split on the `if h : pat.utf8ByteSize = 0` inside buildTable.
split <;> rename_i h
-- Zero-size pattern: the bound hypothesis hi is discharged by simp using h.
· simp [h] at hi
· simp only [Array.emptyWithCapacity_eq, List.push_toArray, List.nil_append]
-- Generalize over all arguments of buildTable.go so that functional induction applies.
suffices ∀ pat' table ht₀ ht h (i : Nat) hi, (buildTable.go pat' table ht₀ ht h)[i]'hi ≤ i from this ..
intro pat' table ht₀ ht h i hi
fun_induction buildTable.go with
-- NOTE(review): presumably the recursive branch, closed by the induction hypothesis — confirm.
| case1 => assumption
-- NOTE(review): presumably the branch returning the table as is; ht' looks like the
-- invariant `table[i] ≤ i` passed to go — confirm.
| case2 table ht₀ ht ht' ht'' => apply ht'

-- Internal state of the forward searcher that looks for a slice pattern inside the slice s.
inductive _root_.String.Slice.Pattern.ForwardSliceSearcher (s : Slice) where
-- Zero-size pattern, positioned at pos; the step from this state reports `.matched pos pos`.
| emptyBefore (pos : s.Pos)
-- Zero-size pattern at a position pos that is known (h) not to be the end of s; the step from
-- this state reports `.rejected pos (pos.next h)` and continues at `pos.next h`.
| emptyAt (pos : s.Pos) (h : pos ≠ s.endPos)
-- Non-empty pattern: the needle, its table (ht pins it to buildTable needle), the current raw
-- position in s (stackPos) and in the needle (needlePos); hn keeps needlePos inside the needle.
| proper (needle : Slice) (table : Vector Nat needle.utf8ByteSize) (ht : table = buildTable needle)
(stackPos : String.Pos.Raw) (needlePos : String.Pos.Raw) (hn : needlePos < needle.rawEndPos)
-- Terminal state: the step from here is `.done`.
| atEnd
deriving Inhabited

@[inline]
def iter (s : Slice) (pat : Slice) : Std.Iter (α := ForwardSliceSearcher s) (SearchStep s) :=
if pat.utf8ByteSize == 0 then
{ internalState := .empty s.startPos }
else
{ internalState := .proper pat (buildTable pat) s.startPos.offset pat.startPos.offset }

partial def backtrackIfNecessary (pat : Slice) (table : Array String.Pos.Raw) (stackByte : UInt8)
(needlePos : String.Pos.Raw) : String.Pos.Raw :=
if needlePos != 0 && stackByte != pat.getUTF8Byte! needlePos then
backtrackIfNecessary pat table stackByte table[needlePos.byteIdx - 1]!
if h : pat.utf8ByteSize = 0 then
{ internalState := .emptyBefore s.startPos }
else
needlePos
{ internalState := .proper pat (buildTable pat) rfl s.startPos.offset pat.startPos.offset
(by simp [Pos.Raw.lt_iff]; omega) }

instance (s : Slice) : Std.Iterators.Iterator (ForwardSliceSearcher s) Id (SearchStep s) where
IsPlausibleStep it
| .yield it' out =>
match it.internalState with
| .empty pos =>
(∃ newPos, pos < newPos ∧ it'.internalState = .empty newPos) ∨
it'.internalState = .atEnd
| .proper needle table stackPos needlePos =>
(∃ newStackPos newNeedlePos,
stackPos < newStackPos ∧
newStackPos ≤ s.rawEndPos ∧
it'.internalState = .proper needle table newStackPos newNeedlePos) ∨
| .yield it' out | .skip it' =>
match it.internalState with
| .emptyBefore pos => (∃ h, it'.internalState = .emptyAt pos h) ∨ it'.internalState = .atEnd
| .emptyAt pos h => ∃ newPos, pos < newPos ∧ it'.internalState = .emptyBefore newPos
| .proper needle table ht stackPos needlePos hn =>
(∃ newStackPos newNeedlePos hn,
it'.internalState = .proper needle table ht newStackPos newNeedlePos hn ∧
((s.utf8ByteSize - newStackPos.byteIdx < s.utf8ByteSize - stackPos.byteIdx) ∨
(newStackPos = stackPos ∧ newNeedlePos < needlePos))) ∨
it'.internalState = .atEnd
| .atEnd => False
| .skip _ => False
| .done => True
step := fun ⟨iter⟩ =>
match iter with
| .empty pos =>
| .emptyBefore pos =>
let res := .matched pos pos
if h : pos ≠ s.endPos then
pure (.deflate ⟨.yield ⟨.empty (pos.next h)⟩ res, by simp⟩)
pure (.deflate ⟨.yield ⟨.emptyAt pos h⟩ res, by simp [h]⟩)
else
pure (.deflate ⟨.yield ⟨.atEnd⟩ res, by simp⟩)
| .proper needle table stackPos needlePos =>
let rec findNext (startPos : String.Pos.Raw)
(currStackPos : String.Pos.Raw) (needlePos : String.Pos.Raw) (h : stackPos ≤ currStackPos) :=
if h1 : currStackPos < s.rawEndPos then
let stackByte := s.getUTF8Byte currStackPos h1
let needlePos := backtrackIfNecessary needle table stackByte needlePos
let patByte := needle.getUTF8Byte! needlePos
if stackByte != patByte then
let nextStackPos := s.findNextPos currStackPos h1 |>.offset
let res := .rejected (s.pos! startPos) (s.pos! nextStackPos)
have hiter := by
left
exists nextStackPos
have haux := lt_offset_findNextPos h1
simp only [String.Pos.Raw.lt_iff, proper.injEq, true_and, exists_and_left, exists_eq', and_true,
nextStackPos]
constructor
· simp [String.Pos.Raw.le_iff, String.Pos.Raw.lt_iff] at h haux ⊢
omega
· apply Pos.Raw.IsValidForSlice.le_utf8ByteSize
apply Pos.isValidForSlice
.deflate ⟨.yield ⟨.proper needle table nextStackPos needlePos⟩ res, hiter⟩
| .emptyAt pos h =>
let res := .rejected pos (pos.next h)
pure (.deflate ⟨.yield ⟨.emptyBefore (pos.next h)⟩ res, by simp⟩)
| .proper needle table htable stackPos needlePos hn =>
-- **Invariant 1:** we have already covered everything up until `stackPos - needlePos` (exclusive),
-- with matches and rejections.
-- **Invariant 2:** `stackPos - needlePos` is a valid position
-- **Invariant 3:** the range from `stackPos - needlePos` to `stackPos` (exclusive) is a
-- prefix of the pattern.
if h₁ : stackPos < s.rawEndPos then
let stackByte := s.getUTF8Byte stackPos h₁
let patByte := needle.getUTF8Byte needlePos hn
if stackByte = patByte then
let nextStackPos := stackPos.inc
let nextNeedlePos := needlePos.inc
if h : nextNeedlePos = needle.rawEndPos then
-- Safety: the section from `nextStackPos.decreaseBy needle.utf8ByteSize` to `nextStackPos`
-- (exclusive) is exactly the needle, so it must represent a valid range.
let res := .matched (s.pos! (nextStackPos.decreaseBy needle.utf8ByteSize)) (s.pos! nextStackPos)
-- Invariants still satisfied
pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos 0
(by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega,
Or.inl (by simp [nextStackPos, Pos.Raw.lt_iff] at h₁ ⊢; omega)⟩⟩)
else
let needlePos := needlePos.inc
if needlePos == needle.rawEndPos then
let nextStackPos := currStackPos.inc
let res := .matched (s.pos! startPos) (s.pos! nextStackPos)
have hiter := by
left
exists nextStackPos
simp only [Pos.Raw.byteIdx_inc, proper.injEq, true_and, exists_and_left,
exists_eq', and_true, nextStackPos, String.Pos.Raw.lt_iff]
constructor
· simp [String.Pos.Raw.le_iff] at h ⊢
omega
· simp [String.Pos.Raw.le_iff, String.Pos.Raw.lt_iff] at h1 ⊢
omega
.deflate ⟨.yield ⟨.proper needle table nextStackPos 0⟩ res, hiter⟩
else
have hinv := by
simp [String.Pos.Raw.le_iff] at h ⊢
omega
findNext startPos currStackPos.inc needlePos hinv
-- Invariants still satisfied
pure (.deflate ⟨.skip ⟨.proper needle table htable nextStackPos nextNeedlePos
(by simp [Pos.Raw.lt_iff, nextNeedlePos, Pos.Raw.ext_iff] at h hn ⊢; omega)⟩,
by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [nextNeedlePos, Pos.Raw.lt_iff, Pos.Raw.ext_iff] at h hn ⊢; omega,
Or.inl (by simp [nextStackPos, Pos.Raw.lt_iff] at h₁ ⊢; omega)⟩⟩)
else
if startPos != s.rawEndPos then
let res := .rejected (s.pos! startPos) (s.pos! currStackPos)
.deflate ⟨.yield ⟨.atEnd⟩ res, by simp⟩
if hnp : needlePos.byteIdx = 0 then
-- Safety: by invariant 2
let basePos := s.pos! stackPos
-- Since we report (mis)matches by code point and not by byte, missing in the first byte
-- means that we should skip ahead to the next code point.
let nextStackPos := s.findNextPos stackPos h₁
let res := .rejected basePos nextStackPos
-- Invariants still satisfied
pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos.offset 0
(by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega,
Or.inl (by
have := lt_offset_findNextPos h₁
have t₀ := (findNextPos _ _ h₁).isValidForSlice.le_utf8ByteSize
simp [nextStackPos, Pos.Raw.lt_iff, Pos.Raw.le_iff] at this t₀ ⊢; omega)⟩⟩)
else
.deflate ⟨.done, by simp⟩
termination_by s.utf8ByteSize - currStackPos.byteIdx
decreasing_by
simp [String.Pos.Raw.lt_iff] at h1 ⊢
omega

findNext stackPos stackPos needlePos (by simp)
let newNeedlePos := table[needlePos.byteIdx - 1]'(by simp [Pos.Raw.lt_iff] at hn; omega)
if newNeedlePos = 0 then
-- Safety: by invariant 2
let basePos := s.pos! (stackPos.unoffsetBy needlePos)
-- Since we report (mis)matches by code point and not by byte, missing in the first byte
-- means that we should skip ahead to the next code point.
let nextStackPos := (s.pos? stackPos).getD (s.findNextPos stackPos h₁)
let res := .rejected basePos nextStackPos
-- Invariants still satisfied
pure (.deflate ⟨.yield ⟨.proper needle table htable nextStackPos.offset 0
(by simp [Pos.Raw.lt_iff] at hn ⊢; omega)⟩ res,
by simpa using ⟨_, _, ⟨rfl, rfl⟩, by simp [Pos.Raw.lt_iff] at hn ⊢; omega, by
simp only [pos?, Pos.Raw.isValidForSlice_eq_true_iff, nextStackPos]
split
· exact Or.inr (by simp [Pos.Raw.lt_iff]; omega)
· refine Or.inl ?_
have := lt_offset_findNextPos h₁
have t₀ := (findNextPos _ _ h₁).isValidForSlice.le_utf8ByteSize
simp [Pos.Raw.lt_iff, Pos.Raw.le_iff] at this t₀ ⊢; omega⟩⟩)
else
let oldBasePos := s.pos! (stackPos.decreaseBy needlePos.byteIdx)
let newBasePos := s.pos! (stackPos.decreaseBy newNeedlePos)
let res := .rejected oldBasePos newBasePos
-- Invariants still satisfied by definition of the prefix table
pure (.deflate ⟨.yield ⟨.proper needle table htable stackPos ⟨newNeedlePos⟩
(by
subst htable
have := getElem_buildTable_le needle (needlePos.byteIdx - 1) (by simp [Pos.Raw.lt_iff] at hn; omega)
simp [newNeedlePos, Pos.Raw.lt_iff] at hn ⊢
omega)⟩ res,
by
simp only [proper.injEq, heq_eq_eq, true_and, exists_and_left, exists_prop,
reduceCtorEq, or_false]
refine ⟨_, _, ⟨rfl, rfl⟩, ?_, Or.inr ⟨rfl, ?_⟩⟩
all_goals
subst htable
have := getElem_buildTable_le needle (needlePos.byteIdx - 1) (by simp [Pos.Raw.lt_iff] at hn; omega)
simp [newNeedlePos, Pos.Raw.lt_iff] at hn ⊢
omega⟩)
else
if 0 < needlePos then
let basePos := stackPos.unoffsetBy needlePos
let res := .rejected (s.pos! basePos) s.endPos
pure (.deflate ⟨.yield ⟨.atEnd⟩ res, by simp⟩)
else
pure (.deflate ⟨.done, by simp⟩)
| .atEnd => pure (.deflate ⟨.done, by simp⟩)

private def toPair : ForwardSliceSearcher s → (Nat × Nat)
| .empty pos => (1, s.utf8ByteSize - pos.offset.byteIdx)
| .proper _ _ sp _ => (1, s.utf8ByteSize - sp.byteIdx)
| .atEnd => (0, 0)
-- Measure used by the WellFoundedRelation instance on searcher states.
-- atEnd maps to none; every other state maps to a pair whose first component is the number of
-- bytes of s from the current position to the end, and whose second component is
-- 1 for emptyBefore, 0 for emptyAt, and the needle byte index for proper.
private def toOption : ForwardSliceSearcher s → Option (Nat × Nat)
| .emptyBefore pos => some (s.utf8ByteSize - pos.offset.byteIdx, 1)
| .emptyAt pos _ => some (s.utf8ByteSize - pos.offset.byteIdx, 0)
| .proper _ _ _ sp np _ => some (s.utf8ByteSize - sp.byteIdx, np.byteIdx)
| .atEnd => none

private instance : WellFoundedRelation (ForwardSliceSearcher s) where
rel s1 s2 := Prod.Lex (· < ·) (· < ·) s1.toPair s2.toPair
rel := InvImage (Option.lt (Prod.Lex (· < ·) (· < ·))) ForwardSliceSearcher.toOption
wf := by
apply InvImage.wf
apply Option.wellFounded_lt
apply (Prod.lex _ _).wf

private def finitenessRelation :
Expand All @@ -167,38 +228,35 @@ private def finitenessRelation :
simp_wf
obtain ⟨step, h, h'⟩ := h
cases step
· cases h
simp only [Std.Iterators.IterM.IsPlausibleStep, Std.Iterators.Iterator.IsPlausibleStep] at h'
split at h'
· next heq =>
rw [heq]
rcases h' with ⟨np, h1', h2'⟩ | h'
· rw [h2']
apply Prod.Lex.right'
· simp
· have haux := np.isValidForSlice.le_utf8ByteSize
simp [Slice.Pos.lt_iff, String.Pos.Raw.le_iff, String.Pos.Raw.lt_iff] at h1' haux ⊢
omega
· apply Prod.Lex.left
simp [h']
· next heq =>
rw [heq]
rcases h' with ⟨np, sp, h1', h2', h3'⟩ | h'
· rw [h3']
apply Prod.Lex.right'
· simp
· simp [String.Pos.Raw.le_iff, String.Pos.Raw.lt_iff] at h1' h2' ⊢
omega
· apply Prod.Lex.left
simp [h']
· contradiction
· cases h'
all_goals try
cases h
revert h'
simp only [Std.Iterators.IterM.IsPlausibleStep, Std.Iterators.Iterator.IsPlausibleStep]
match it.internalState with
| .emptyBefore pos =>
rintro (⟨h, h'⟩|h') <;> simp [h', ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def]
| .emptyAt pos h =>
simp only [forall_exists_index, and_imp]
intro x hx h
have := x.isValidForSlice.le_utf8ByteSize
simp [h, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, Pos.lt_iff,
Pos.Raw.lt_iff, Pos.Raw.le_iff] at hx ⊢ this
omega
| .proper .. =>
rintro (⟨newStackPos, newNeedlePos, h₁, h₂, (h|⟨rfl, h⟩)⟩|h)
· simp [h₂, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, h]
· simpa [h₂, ForwardSliceSearcher.toOption, Option.lt, Prod.lex_def, Pos.Raw.lt_iff]
· simp [h, ForwardSliceSearcher.toOption, Option.lt]
| .atEnd .. => simp
· cases h

-- Finiteness of the searcher iterator, obtained from finitenessRelation.
@[no_expose]
instance : Std.Iterators.Finite (ForwardSliceSearcher s) Id :=
.of_finitenessRelation finitenessRelation

-- Collecting from the searcher iterator uses the default implementation.
instance : Std.Iterators.IteratorCollect (ForwardSliceSearcher s) Id Id :=
.defaultImplementation

-- Looping over the searcher iterator uses the default implementation.
instance : Std.Iterators.IteratorLoop (ForwardSliceSearcher s) Id Id :=
.defaultImplementation

Expand Down
Loading
Loading