Skip to content

Commit e35d651

Browse files
wkrozowski and TwoFX authored
feat: add intersection on DHashMap (#11112)
This PR adds intersection operation on `DHashMap`/`HashMap`/`HashSet` and provides several lemmas about its behaviour. --------- Co-authored-by: Markus Himmel <[email protected]>
1 parent 1a4c3ca commit e35d651

File tree

19 files changed

+2978
-51
lines changed

19 files changed

+2978
-51
lines changed

src/Std/Data/DHashMap/Basic.lean

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,18 @@ This function always merges the smaller map into the larger map, so the expected
336336
inner := Raw₀.union ⟨m₁.1, m₁.2.size_buckets_pos⟩ ⟨m₂.1, m₂.2.size_buckets_pos⟩
337337
wf := Std.DHashMap.Raw.WF.union₀ m₁.2 m₂.2
338338

339+
/--
340+
Computes the intersection of the given hash maps. The result will only contain entries from the first map.
341+
342+
This function always iterates through the smaller map, so the expected runtime is
343+
`O(min(m₁.size, m₂.size))`.
344+
-/
345+
@[inline] def inter [BEq α] [Hashable α] (m₁ m₂ : DHashMap α β) : DHashMap α β where
346+
-- Delegate to `Raw₀.inter`; each raw map is bundled with its `size_buckets_pos` proof to form a `Raw₀` value.
inner := Raw₀.inter ⟨m₁.1, m₁.2.size_buckets_pos⟩ ⟨m₂.1, m₂.2.size_buckets_pos⟩
347+
-- Well-formedness of the result is obtained from the well-formedness proofs of both inputs via `WF.inter₀`.
wf := Std.DHashMap.Raw.WF.inter₀ m₁.2 m₂.2
348+
339349
-- Registers `union` as the `Union` instance for `DHashMap`.
instance [BEq α] [Hashable α] : Union (DHashMap α β) := ⟨union⟩
350+
-- Registers `inter` as the `Inter` instance for `DHashMap`.
instance [BEq α] [Hashable α] : Inter (DHashMap α β) := ⟨inter⟩
340351

341352
section Unverified
342353

src/Std/Data/DHashMap/Internal/Defs.lean

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,10 +461,25 @@ def insertMany {ρ : Type w} [ForIn Id ρ ((a : α) × β a)] [BEq α] [Hashable
461461
r := ⟨r.1.insertIfNew a b, fun _ h hm => h (r.2 _ h hm)⟩
462462
return r
463463

464+
/-- Internal implementation detail of the hash map -/
465+
@[inline]
466+
def interSmallerFn [BEq α] [Hashable α] (m sofar : Raw₀ α β) (k : α) : Raw₀ α β :=
467+
-- Single accumulation step used by `interSmaller`: look `k` up in `m`.
match m.getEntry? k with
468+
-- Hit: insert the entry as returned by `m.getEntry?` (its key and its value) into the accumulator.
| some kv' => sofar.insert kv'.1 kv'.2
469+
-- Miss: the accumulator is returned unchanged.
| none => sofar
470+
471+
/-- Internal implementation detail of the hash map -/
472+
def interSmaller [BEq α] [Hashable α] (m₁ : Raw₀ α β) (m₂ : Raw α β) : Raw₀ α β :=
473+
-- Fold over `m₂`, ignoring its values: every key is fed to `interSmallerFn m₁`, starting from `emptyWithCapacity`.
(m₂.fold (fun sofar k _ => interSmallerFn m₁ sofar k) emptyWithCapacity)
474+
464475
/-- Internal implementation detail of the hash map -/
465476
@[inline] def union [BEq α] [Hashable α] (m₁ m₂ : Raw₀ α β) : Raw₀ α β :=
466477
-- Branch on size: when `m₁` is no larger, add `m₁` to `m₂` via `insertManyIfNew`; otherwise add `m₂` to `m₁` via `insertMany`.
if m₁.1.size ≤ m₂.1.size then (m₂.insertManyIfNew m₁.1).1 else (m₁.insertMany m₂.1).1
467478

479+
/-- Internal implementation detail of the hash map -/
480+
def inter [BEq α] [Hashable α] (m₁ m₂ : Raw₀ α β) : Raw₀ α β :=
481+
-- When `m₁.1.size ≤ m₂.1.size`, keep the entries of `m₁` whose key `m₂` contains;
-- otherwise use `interSmaller`, which folds over `m₂` and looks each key up in `m₁`.
if m₁.1.size ≤ m₂.1.size then m₁.filter fun k _ => m₂.contains k else interSmaller m₁ m₂
482+
468483
section
469484

470485
variable {β : Type v}

src/Std/Data/DHashMap/Internal/Model.lean

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,12 @@ def unionₘ [BEq α] [Hashable α] (m₁ m₂ : Raw₀ α β) : Raw₀ α β :=
420420
else
421421
insertListₘ m₁ (toListModel m₂.1.buckets)
422422

423+
/-- Internal implementation detail of the hash map -/
424+
def interSmallerFnₘ [BEq α] [Hashable α] (m sofar : Raw₀ α β) (k : α) : Raw₀ α β :=
425+
-- Same shape as `interSmallerFn`, but written with the `ₘ` operations (`getEntry?ₘ`, `insertₘ`).
match m.getEntry?ₘ k with
426+
-- Hit: insert the found entry's key and value into the accumulator.
| some kv' => sofar.insertₘ kv'.1 kv'.2
427+
-- Miss: the accumulator is returned unchanged.
| none => sofar
428+
423429
section
424430

425431
variable {β : Type v}
@@ -664,6 +670,14 @@ theorem insertManyIfNew_eq_insertListIfNewₘ [BEq α] [Hashable α] (m : Raw₀
664670
simp only [List.foldl_cons, insertListIfNewₘ]
665671
apply ih
666672

673+
-- `interSmallerFn` and `interSmallerFnₘ` produce equal results on the same arguments.
theorem interSmallerFn_eq_interSmallerFnₘ [BEq α] [Hashable α] (m sofar : Raw₀ α β) (k : α) :
674+
interSmallerFn m sofar k = interSmallerFnₘ m sofar k := by
675+
-- Unfold both definitions so that the two `match` expressions are exposed.
rw [interSmallerFn, interSmallerFnₘ]
676+
-- Make the scrutinees agree by rewriting `getEntry?` to `getEntry?ₘ`.
rw [getEntry?_eq_getEntry?ₘ]
677+
congr
678+
ext
679+
-- What remains is the `some` branch: `insert` versus `insertₘ`.
rw [insert_eq_insertₘ]
680+
667681
section
668682

669683
variable {β : Type v}

src/Std/Data/DHashMap/Internal/Raw.lean

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,10 @@ theorem union_eq [BEq α] [Hashable α] {m₁ m₂ : Raw α β} (h₁ : m₁.WF)
145145
m₁.union m₂ = Raw₀.union ⟨m₁, h₁.size_buckets_pos⟩ ⟨m₂, h₂.size_buckets_pos⟩ := by
146146
simp [Raw.union, h₁.size_buckets_pos, h₂.size_buckets_pos]
147147

148+
-- For well-formed `m₁` and `m₂`, `Raw.inter` equals `Raw₀.inter` applied to the same maps
-- bundled with their `size_buckets_pos` proofs.
theorem inter_eq [BEq α] [Hashable α] {m₁ m₂ : Raw α β} (h₁ : m₁.WF) (h₂ : m₂.WF) :
149+
m₁.inter m₂ = Raw₀.inter ⟨m₁, h₁.size_buckets_pos⟩ ⟨m₂, h₂.size_buckets_pos⟩ := by
150+
simp [Raw.inter, h₁.size_buckets_pos, h₂.size_buckets_pos]
151+
148152
section
149153

150154
variable {β : Type v}

src/Std/Data/DHashMap/Internal/RawLemmas.lean

Lines changed: 349 additions & 20 deletions
Large diffs are not rendered by default.

src/Std/Data/DHashMap/Internal/WF.lean

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,23 @@ theorem foldM_eq_foldlM_toListModel {δ : Type w} {m : Type w → Type w'} [Mona
167167
funext init'
168168
rw [ih]
169169

170+
-- Induction principle for `Raw.fold`: if `P` holds for `init` (`base`) and is preserved by every
-- application of `f` (`step`), then `P` holds for `b.fold f init`.
-- Note: the bound variable `b` in `step` shadows the map `b : Raw α β` inside that binder only.
theorem fold_induction {δ : Type w}
171+
{f : δ → (a : α) → β a → δ} {init : δ} {b : Raw α β} {P : δ → Prop}
172+
(base : P init) (step : ∀ acc a b , P acc → P (f acc a b)) :
173+
P (b.fold f init) := by
174+
-- Unfold `fold`/`foldM` and restate the array fold as a fold over `b.buckets.toList`.
simp [Raw.fold, Raw.foldM, ← Array.foldlM_toList]
175+
-- Outer induction: over the list of buckets, generalizing the accumulator.
induction b.buckets.toList generalizing init with
176+
| nil => simp [base]
177+
| cons hd tl ih =>
178+
apply ih
179+
-- Inner induction: over the entries of the head bucket `hd`.
induction hd generalizing init with
180+
| nil => simp [AssocList.foldlM, pure, base]
181+
| cons hda hdb tl ih =>
182+
simp only [AssocList.foldlM, pure_bind]
183+
apply ih
184+
-- One application of `f` preserves `P` by `step`, starting from `base`.
apply step
185+
exact base
186+
170187
theorem fold_eq_foldl_toListModel {l : Raw α β} {f : γ → (a : α) → β a → γ} {init : γ} :
171188
l.fold f init = (toListModel l.buckets).foldl (fun a b => f a b.1 b.2) init := by
172189
simp [Raw.fold, foldM_eq_foldlM_toListModel]
@@ -1107,7 +1124,6 @@ theorem insertMany_eq_insertListₘ_toListModel [BEq α] [Hashable α] (m m₂ :
11071124
simp only [List.foldl_cons, insertListₘ]
11081125
apply ih
11091126

1110-
11111127
theorem insertManyIfNew_eq_insertListIfNewₘ_toListModel [BEq α] [Hashable α] (m m₂ : Raw₀ α β) :
11121128
insertManyIfNew m m₂.1 = insertListIfNewₘ m (toListModel m₂.1.buckets) := by
11131129
simp only [insertManyIfNew, bind_pure_comp, map_pure, bind_pure]
@@ -1142,6 +1158,20 @@ theorem toListModel_unionₘ [BEq α] [Hashable α] [EquivBEq α] [LawfulHashabl
11421158
· exact toListModel_insertListIfNewₘ ‹_›
11431159
· exact toListModel_insertListₘ ‹_›
11441160

1161+
-- `Raw₀.inter` yields a map satisfying `Raw.WFImp`; only `WFImp` of the first argument `m₁` is assumed.
theorem wfImp_inter [BEq α] [EquivBEq α] [Hashable α] [LawfulHashable α]
1162+
{m₁ m₂ : Raw α β} {h₁ : 0 < m₁.buckets.size} {h₂ : 0 < m₂.buckets.size} (wh₁ : Raw.WFImp m₁) :
1163+
Raw.WFImp (Raw₀.inter ⟨m₁, h₁⟩ ⟨m₂, h₂⟩).val := by
1164+
rw [inter]
1165+
-- Case split on the size comparison inside `inter`.
split
1166+
-- Filter branch: the result is a filter of `m₁`.
· apply wfImp_filter wh₁
1167+
-- `interSmaller` branch: use `fold_induction` with invariant `Raw.WFImp ·.val`, starting from the empty map.
· rw [interSmaller]
1168+
apply @Raw.fold_induction _ β _ (fun sofar k x => interSmallerFn ⟨m₁, h₁⟩ sofar k) emptyWithCapacity m₂ (Raw.WFImp ·.val) wfImp_emptyWithCapacity
1169+
intro acc a b wf
1170+
rw [interSmallerFn]
1171+
-- A step either inserts into the accumulator or returns it unchanged.
split
1172+
· apply wfImp_insert wf
1173+
· apply wf
1174+
11451175
end Raw₀
11461176

11471177
namespace Raw
@@ -1163,6 +1193,7 @@ theorem WF.out [BEq α] [Hashable α] [i₁ : EquivBEq α] [i₂ : LawfulHashabl
11631193
| constModify₀ _ h => exact Raw₀.Const.wfImp_modify (by apply h)
11641194
| alter₀ _ h => exact Raw₀.wfImp_alter (by apply h)
11651195
| constAlter₀ _ h => exact Raw₀.Const.wfImp_alter (by apply h)
1196+
| inter₀ _ _ h _ => exact Raw₀.wfImp_inter (by apply h)
11661197

11671198
end Raw
11681199

@@ -1312,13 +1343,84 @@ theorem wf_union₀ [BEq α] [Hashable α] [EquivBEq α] [LawfulHashable α]
13121343
· exact wf_insertManyIfNew₀ ‹_›
13131344
· exact wf_insertMany₀ ‹_›
13141345

1346+
13151347
-- Up to permutation, the list model of `m₁.union m₂` is `List.insertList` applied to the
-- list models of `m₁` and `m₂`.
theorem toListModel_union [BEq α] [Hashable α] [EquivBEq α] [LawfulHashable α] {m₁ m₂ : Raw₀ α β}
13161348
(h₁ : Raw.WFImp m₁.1) (h₂ : Raw.WFImp m₂.1) :
13171349
Perm (toListModel (m₁.union m₂).1.buckets)
13181350
(List.insertList (toListModel m₁.1.buckets) (toListModel m₂.1.buckets)) := by
13191351
-- Switch to `unionₘ` and reuse its list-model lemma.
rw [union_eq_unionₘ]
13201352
exact toListModel_unionₘ h₁ h₂
13211353

1354+
1355+
/-! # `inter` -/
1356+
1357+
-- `interSmallerFnₘ` preserves `Raw.WFImp` of the accumulator. Here `m₂` plays the role of the
-- accumulator (`sofar`) and `m₁` is the map that is looked up; no hypothesis on `m₁` is needed.
theorem wfImp_interSmallerFnₘ [BEq α] [EquivBEq α] [Hashable α] [LawfulHashable α] (m₁ : Raw₀ α β) (m₂ : Raw₀ α β)
1358+
(hm₂ : Raw.WFImp m₂.1) (k : α) : Raw.WFImp (m₁.interSmallerFnₘ m₂ k).1 := by
1359+
rw [interSmallerFnₘ]
1360+
-- Either an `insertₘ` into the accumulator, or the accumulator itself.
split
1361+
· exact wfImp_insertₘ hm₂
1362+
· exact hm₂
1363+
1364+
/-- Internal implementation detail of the hash map -/
1365+
def interSmallerₘ [BEq α] [Hashable α] (m₁ : Raw₀ α β) (m₂ : Raw α β) : Raw₀ α β :=
1366+
-- Same fold as `interSmaller`, but stepping with `interSmallerFnₘ` instead of `interSmallerFn`.
m₂.fold (fun sofar k _ => interSmallerFnₘ m₁ sofar k) emptyWithCapacity
1367+
1368+
-- `interSmaller` and `interSmallerₘ` produce equal results on the same arguments.
theorem interSmaller_eq_interSmallerₘ [BEq α] [Hashable α] (m₁ : Raw₀ α β) (m₂ : Raw α β) :
1369+
m₁.interSmaller m₂ = m₁.interSmallerₘ m₂ := by
1370+
rw [interSmaller, interSmallerₘ]
1371+
-- The two folds differ only in the step function; rewrite it with the step-level equality.
simp only [interSmallerFn_eq_interSmallerFnₘ]
1372+
1373+
-- Congruence of `List.foldl` under permutation of the initial accumulator: if `init₁` and `init₂`
-- are permutations of each other (`h₁`), `init₁` has distinct keys (`h₃`), and `f` maps permuted
-- accumulators with distinct keys to permuted results while preserving distinct keys (`h₂`),
-- then folding `f` over `l` from `init₁` and from `init₂` gives permuted results.
theorem foldl_perm_cong [BEq α] {init₁ init₂ : List ((a : α) × β a)} {l : List ((a : α) × β a)}
1374+
{f : List ((a : α) × β a) → ((a : α) × β a) → List ((a : α) × β a)} (h₁ : Perm init₁ init₂)
1375+
(h₂ : ∀ h l₁ l₂, (w : DistinctKeys l₁) → Perm l₁ l₂ → Perm (f l₁ h) (f l₂ h) ∧ DistinctKeys (f l₁ h))
1376+
(h₃ : DistinctKeys init₁)
1377+
: Perm (List.foldl f init₁ l) (List.foldl f init₂ l) := by
1378+
-- Induct on the folded list, generalizing both accumulators.
induction l generalizing init₁ init₂
1379+
case nil =>
1380+
simp only [foldl_nil, h₁]
1381+
case cons h t ih =>
1382+
simp only [foldl_cons]
1383+
apply ih
1384+
-- Both obligations of `ih` come from the two components of `h₂` at the head element.
· exact (h₂ h init₁ init₂ h₃ h₁).1
1385+
· exact (h₂ h init₁ init₂ h₃ h₁).2
1386+
1387+
-- List-model description of one `interSmallerFnₘ` step: if `l` is a permutation of the list model
-- of `sofar`, then the list model of `interSmallerFnₘ m sofar k` is a permutation of
-- `List.interSmallerFn (toListModel m.1.buckets) l k`.
theorem toListModel_interSmallerFnₘ [BEq α] [EquivBEq α] [Hashable α] [LawfulHashable α] (m sofar : Raw₀ α β)
1388+
(l : List ((a : α) × β a))
1389+
(hm : Raw.WFImp m.1) (hs : Raw.WFImp sofar.1) (k : α) (hml : toListModel sofar.1.buckets ~ l) :
1390+
Perm (toListModel ((interSmallerFnₘ m sofar k).1.buckets))
1391+
(List.interSmallerFn (toListModel m.1.buckets) l k) := by
1392+
-- Unfold both step functions and express the map lookup as a list-level `getEntry?`.
rw [interSmallerFnₘ, getEntry?ₘ_eq_getEntry? hm, List.interSmallerFn]
1393+
split
1394+
-- Found: combine the list model of `insertₘ` with `insertEntry` respecting permutations.
· simpa [*] using (toListModel_insertₘ hs).trans (List.insertEntry_of_perm hs.distinct hml)
1395+
-- Not found: both sides are the unchanged accumulator.
· simp [*]
1396+
1397+
-- List-model description of `interSmallerₘ`: its list model is a permutation of
-- `List.interSmaller` applied to the list models of `m₁` and `m₂`. Only `WFImp` of `m₁` is assumed.
theorem toListModel_interSmallerₘ [BEq α] [EquivBEq α] [Hashable α] [LawfulHashable α]
1398+
(m₁ : Raw₀ α β) (m₂ : Raw α β) (hm₁ : Raw.WFImp m₁.1) :
1399+
toListModel (m₁.interSmallerₘ m₂).1.buckets ~
1400+
List.interSmaller (toListModel m₁.1.buckets) (toListModel m₂.buckets) := by
1401+
-- Turn the map fold into a `foldl` over the list model of `m₂`.
rw [interSmallerₘ, Raw.fold_eq_foldl_toListModel, List.interSmaller]
1402+
generalize toListModel m₂.buckets = l
1403+
-- Strengthen the goal to arbitrary accumulators `m` / `l'` related by permutation, so induction goes through.
suffices ∀ m l', Raw.WFImp m.1 → toListModel m.1.buckets ~ l' → toListModel (foldl (fun a b => m₁.interSmallerFnₘ a b.fst) m l).val.buckets ~
1404+
foldl (fun sofar kv => List.interSmallerFn (toListModel m₁.val.buckets) sofar kv.fst) l' l by
1405+
-- Instantiate with the empty map and the empty list.
simpa using this emptyWithCapacity [] wfImp_emptyWithCapacity (by simp)
1406+
intro m l' hm hml'
1407+
induction l generalizing m l' with
1408+
| nil => simpa
1409+
| cons ht tl ih =>
1410+
rw [List.foldl_cons, List.foldl_cons]
1411+
-- One step preserves both `WFImp` of the accumulator and the permutation relation.
exact ih _ _ (wfImp_interSmallerFnₘ _ _ hm _) (toListModel_interSmallerFnₘ _ _ _ hm₁ hm _ hml')
1412+
1413+
-- List-model description of `inter`: the list model of `m₁.inter m₂` is a permutation of the list
-- model of `m₁` filtered to the entries whose key is contained in the list model of `m₂`.
theorem toListModel_inter [BEq α] [EquivBEq α] [Hashable α] [LawfulHashable α] (m₁ m₂ : Raw₀ α β) (hm₁ : Raw.WFImp m₁.1) (hm₂ : Raw.WFImp m₂.1) :
1414+
Perm (toListModel (m₁.inter m₂).1.buckets) ((toListModel m₁.1.buckets).filter fun p => containsKey p.1 (toListModel m₂.1.buckets) ) := by
1415+
simp [inter]
1416+
-- Case split on the size comparison inside `inter`.
split
1417+
-- Filter branch: go through `filterₘ` and restate `contains` as `containsKey` on the list model.
· rw [filter_eq_filterₘ]
1418+
simp only [contains_eq_containsKey hm₂]
1419+
exact toListModel_filterₘ
1420+
-- `interSmaller` branch: chain the list-model lemma for `interSmallerₘ` with `interSmaller_perm_filter`.
· rw [interSmaller_eq_interSmallerₘ]
1421+
exact Perm.trans (toListModel_interSmallerₘ _ _ hm₁)
1422+
(interSmaller_perm_filter _ _ hm₁.distinct)
1423+
13221424
/-! # `Const.insertListₘ` -/
13231425

13241426
theorem Const.toListModel_insertListₘ {β : Type v} [BEq α] [Hashable α] [EquivBEq α]

0 commit comments

Comments
 (0)