Skip to content

Commit b39ee8a

Browse files
authored
feat: add minimal support for getEntry/getEntry?/getEntry!/getEntryD for DHashMap (#11076)
This PR adds `getEntry`/`getEntry?`/`getEntry!`/`getEntryD` operation on DHashMap.
1 parent 9a3fb90 commit b39ee8a

File tree

9 files changed

+331
-4
lines changed

9 files changed

+331
-4
lines changed

src/Std/Data/DHashMap/Basic.lean

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,18 @@ end
190190
@[inline, inherit_doc Raw.getKeyD] def getKeyD (m : DHashMap α β) (a : α) (fallback : α) : α :=
191191
Raw₀.getKeyD ⟨m.1, m.2.size_buckets_pos⟩ a fallback
192192

193+
@[inline, inherit_doc Raw.getEntry?] def getEntry? (m : DHashMap α β) (a : α) : Option ((a : α) × β a) :=
194+
Raw₀.getEntry? ⟨m.1, m.2.size_buckets_pos⟩ a
195+
196+
@[inline, inherit_doc Raw.getEntry] def getEntry (m : DHashMap α β) (a : α) (h : a ∈ m) : (a : α) × β a :=
197+
Raw₀.getEntry ⟨m.1, m.2.size_buckets_pos⟩ a h
198+
199+
@[inline, inherit_doc Raw.getEntry!] def getEntry! [Inhabited ((a : α) × β a)] (m : DHashMap α β) (a : α) : (a : α) × β a :=
200+
Raw₀.getEntry! ⟨m.1, m.2.size_buckets_pos⟩ a
201+
202+
@[inline, inherit_doc Raw.getEntryD] def getEntryD (m : DHashMap α β) (a : α) (fallback : (a : α) × β a) : (a : α) × β a :=
203+
Raw₀.getEntryD ⟨m.1, m.2.size_buckets_pos⟩ a fallback
204+
193205
@[inline, inherit_doc Raw.size] def size (m : DHashMap α β) : Nat :=
194206
m.1.size
195207

src/Std/Data/DHashMap/Internal/AssocList/Basic.lean

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ def getCast? [BEq α] [LawfulBEq α] (a : α) : AssocList α β → Option (β a
107107
| cons k v es => if h : k == a then some (cast (congrArg β (eq_of_beq h)) v)
108108
else es.getCast? a
109109

110+
/-- Internal implementation detail of the hash map -/
111+
def getEntry? [BEq α] (a : α) : (l : AssocList α β) → Option ((a : α) × β a)
112+
| nil => none
113+
| cons k v es => if k == a then some ⟨k, v⟩
114+
else es.getEntry? a
115+
110116
/-- Internal implementation detail of the hash map -/
111117
def contains [BEq α] (a : α) : AssocList α β → Bool
112118
| nil => false
@@ -122,6 +128,21 @@ def getCast [BEq α] [LawfulBEq α] (a : α) : (l : AssocList α β) → l.conta
122128
| cons k v es, h => if hka : k == a then cast (congrArg β (eq_of_beq hka)) v
123129
else es.getCast a (by rw [← h, contains, Bool.of_not_eq_true hka, Bool.false_or])
124130

131+
/-- Internal implementation detail of the hash map -/
132+
def getEntry [BEq α] (a : α) : (l : AssocList α β) → l.contains a → (a : α) × β a
133+
| cons k v es, h => if hka : k == a then ⟨k, v⟩
134+
else es.getEntry a (by rw [← h, contains, Bool.of_not_eq_true hka, Bool.false_or])
135+
136+
/-- Internal implementation detail of the hash map -/
137+
def getEntryD [BEq α] (a : α) (fallback : (a : α) × β a) : AssocList α β → (a : α) × β a
138+
| nil => fallback
139+
| cons k v es => if k == a then ⟨k, v⟩ else es.getEntryD a fallback
140+
141+
/-- Internal implementation detail of the hash map -/
142+
def getEntry! [BEq α] (a : α) [Inhabited ((a : α) × β a)] : AssocList α β → (a : α) × β a
143+
| nil => default
144+
| cons k v es => if k == a then ⟨k, v⟩ else es.getEntry! a
145+
125146
/-- Internal implementation detail of the hash map -/
126147
def getKey [BEq α] (a : α) : (l : AssocList α β) → l.contains a → α
127148
| cons k _ es, h => if hka : k == a then k

src/Std/Data/DHashMap/Internal/AssocList/Lemmas.lean

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,38 @@ theorem get_eq {β : Type v} [BEq α] {l : AssocList α (fun _ => β)} {a : α}
8282
· simp [contains] at h
8383
next k v t ih => simp only [get, toList_cons, List.getValue_cons, ih]
8484

85+
@[simp]
86+
theorem getEntry_eq [BEq α] {l : AssocList α β} {a : α} {h} :
87+
l.getEntry a h = List.getEntry a l.toList (contains_eq.symm.trans h) := by
88+
induction l
89+
· simp [contains] at h
90+
next k v t ih =>
91+
simp only [getEntry, toList_cons, List.getEntry_cons, ih]
92+
93+
@[simp]
94+
theorem getEntry?_eq [BEq α] {l : AssocList α β} {a : α} :
95+
l.getEntry? a = List.getEntry? a l.toList := by
96+
induction l
97+
· simp only [getEntry?, toList_nil, getEntry?_nil]
98+
next k v t ih =>
99+
simp only [getEntry?, ih, toList_cons, getEntry?_cons, Bool.ite_eq_cond_iff]
100+
101+
@[simp]
102+
theorem getEntryD_eq [BEq α] {l : AssocList α β} {a : α} {fallback : (a : α) × β a} :
103+
l.getEntryD a fallback = List.getEntryD a fallback l.toList := by
104+
induction l
105+
· simp only [getEntryD, toList_nil, getEntryD_nil]
106+
next k v t ih =>
107+
simp only [getEntryD, ih, toList_cons, getEntryD_cons, Bool.ite_eq_cond_iff]
108+
109+
@[simp]
110+
theorem getEntry!_eq [BEq α] {l : AssocList α β} {a : α} [Inhabited ((a : α) × β a)] :
111+
l.getEntry! a = List.getEntry! a l.toList := by
112+
induction l
113+
· simp only [getEntry!, toList_nil, getEntry!_nil]
114+
next k v t ih =>
115+
simp only [getEntry!, ih, toList_cons, List.getEntry!_cons, Bool.ite_eq_cond_iff]
116+
85117
@[simp]
86118
theorem getCastD_eq [BEq α] [LawfulBEq α] {l : AssocList α β} {a : α} {fallback : β a} :
87119
l.getCastD a fallback = getValueCastD a l.toList fallback := by

src/Std/Data/DHashMap/Internal/Defs.lean

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,34 @@ def get [BEq α] [LawfulBEq α] [Hashable α] (m : Raw₀ α β) (a : α) (hma :
368368
let idx := mkIdx buckets.size h (hash a)
369369
buckets[idx.1].getCast a hma
370370

371+
/-- Internal implementation detail of the hash map -/
372+
def getEntry [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) (hma : m.contains a) :
373+
(a : α) × β a :=
374+
let ⟨⟨_, buckets⟩, h⟩ := m
375+
let idx := mkIdx buckets.size h (hash a)
376+
buckets[idx.1].getEntry a hma
377+
378+
/-- Internal implementation detail of the hash map -/
379+
def getEntry? [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) :
380+
Option ((a : α) × β a) :=
381+
let ⟨⟨_, buckets⟩, h⟩ := m
382+
let idx := mkIdx buckets.size h (hash a)
383+
buckets[idx.1].getEntry? a
384+
385+
/-- Internal implementation detail of the hash map -/
386+
def getEntryD [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) (fallback : (a : α) × β a) :
387+
(a : α) × β a :=
388+
let ⟨⟨_, buckets⟩, h⟩ := m
389+
let idx := mkIdx buckets.size h (hash a)
390+
buckets[idx.1].getEntryD a fallback
391+
392+
/-- Internal implementation detail of the hash map -/
393+
def getEntry! [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) [Inhabited ((a : α) × β a)] :
394+
(a : α) × β a :=
395+
let ⟨⟨_, buckets⟩, h⟩ := m
396+
let idx := mkIdx buckets.size h (hash a)
397+
buckets[idx.1].getEntry! a
398+
371399
/-- Internal implementation detail of the hash map -/
372400
def getD [BEq α] [LawfulBEq α] [Hashable α] (m : Raw₀ α β) (a : α) (fallback : β a) :
373401
β a :=

src/Std/Data/DHashMap/Internal/Model.lean

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,22 @@ def containsₘ [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) : Bool :=
294294
def getₘ [BEq α] [LawfulBEq α] [Hashable α] (m : Raw₀ α β) (a : α) (h : m.containsₘ a) : β a :=
295295
(bucket m.1.buckets m.2 a).getCast a h
296296

297+
/-- Internal implementation detail of the hash map -/
298+
def getEntryₘ [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) (h : m.containsₘ a) : (a : α) × β a :=
299+
(bucket m.1.buckets m.2 a).getEntry a h
300+
301+
/-- Internal implementation detail of the hash map -/
302+
def getEntry?ₘ [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) : Option ((a : α) × β a) :=
303+
(bucket m.1.buckets m.2 a).getEntry? a
304+
305+
/-- Internal implementation detail of the hash map -/
306+
def getEntryDₘ [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) (fallback : (a : α) × β a) : (a : α) × β a :=
307+
(bucket m.1.buckets m.2 a).getEntryD a fallback
308+
309+
/-- Internal implementation detail of the hash map -/
310+
def getEntry!ₘ [BEq α] [Hashable α] [Inhabited ((a : α) × β a)] (m : Raw₀ α β) (a : α) : (a : α) × β a :=
311+
(bucket m.1.buckets m.2 a).getEntry! a
312+
297313
/-- Internal implementation detail of the hash map -/
298314
def getDₘ [BEq α] [LawfulBEq α] [Hashable α] (m : Raw₀ α β) (a : α) (fallback : β a) : β a :=
299315
(m.get?ₘ a).getD fallback
@@ -452,6 +468,18 @@ theorem get?_eq_get?ₘ [BEq α] [LawfulBEq α] [Hashable α] (m : Raw₀ α β)
452468
theorem get_eq_getₘ [BEq α] [LawfulBEq α] [Hashable α] (m : Raw₀ α β) (a : α) (h : m.contains a) :
453469
get m a h = getₘ m a (by exact h) := (rfl)
454470

471+
theorem getEntry_eq_getEntryₘ [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) (h : m.contains a) :
472+
getEntry m a h = getEntryₘ m a (by exact h) := (rfl)
473+
474+
theorem getEntry?_eq_getEntry?ₘ [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) :
475+
getEntry? m a = getEntry?ₘ m a := (rfl)
476+
477+
theorem getEntryD_eq_getEntryDₘ [BEq α] [Hashable α] (m : Raw₀ α β) (a : α) (fallback : (a : α) × β a) :
478+
getEntryD m a fallback = getEntryDₘ m a fallback := (rfl)
479+
480+
theorem getEntry!_eq_getEntry!ₘ [BEq α] [Hashable α] [Inhabited ((a : α) × β a)] (m : Raw₀ α β) (a : α) :
481+
getEntry! m a = getEntry!ₘ m a := (rfl)
482+
455483
theorem getD_eq_getDₘ [BEq α] [LawfulBEq α] [Hashable α] (m : Raw₀ α β) (a : α) (fallback : β a) :
456484
getD m a fallback = getDₘ m a fallback := by
457485
simp [getD, getDₘ, get?ₘ, List.getValueCastD_eq_getValueCast?, bucket]

src/Std/Data/DHashMap/Internal/RawLemmas.lean

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,10 @@ private meta def queryMap : Std.DHashMap Name (fun _ => Name × Array (MacroM (T
140140
`getKey, (``getKey_eq_getKey, #[`(getKey_of_perm _)])⟩,
141141
`getKeyD, (``getKeyD_eq_getKeyD, #[`(getKeyD_of_perm _)])⟩,
142142
`getKey!, (``getKey!_eq_getKey!, #[`(getKey!_of_perm _)])⟩,
143+
`getEntry, (``getEntry_eq_getEntry, #[`(getEntry_of_perm _)])⟩,
144+
`getEntry?, (``getEntry?_eq_getEntry?, #[`(getEntry?_of_perm _)])⟩,
145+
`getEntryD, (``getEntryD_eq_getEntryD, #[`(getEntryD_of_perm _)])⟩,
146+
`getEntry!, (``getEntry!_eq_getEntry!, #[`(getEntry!_of_perm _)])⟩,
143147
`toList, (``Raw.toList_eq_toListModel, #[])⟩,
144148
`keys, (``Raw.keys_eq_keys_toListModel, #[`(perm_keys_congr_left)])⟩,
145149
`Const.toList, (``Raw.Const.toList_eq_toListModel_map, #[`(perm_map_congr_left)])⟩,
@@ -2424,7 +2428,7 @@ theorem union_equiv_congr_left {m₃ : Raw₀ α β} [EquivBEq α] [LawfulHashab
24242428
(h₁ : m₁.val.WF) (h₂ : m₂.val.WF) (h₃ : m₃.val.WF) (equiv : m₁.1.Equiv m₂.1) :
24252429
(m₁.union m₃).1.Equiv (m₂.union m₃).1 := by
24262430
revert equiv
2427-
simp_to_model [union]
2431+
simp_to_model [Equiv, union]
24282432
intro equiv
24292433
apply List.insertList_perm_of_perm_first equiv
24302434
wf_trivial
@@ -2433,18 +2437,18 @@ theorem union_equiv_congr_right {m₃ : Raw₀ α β} [EquivBEq α] [LawfulHasha
24332437
(h₁ : m₁.val.WF) (h₂ : m₂.val.WF) (h₃ : m₃.val.WF) (equiv : m₂.1.Equiv m₃.1) :
24342438
(m₁.union m₂).1.Equiv (m₁.union m₃).1 := by
24352439
revert equiv
2436-
simp_to_model [union]
2440+
simp_to_model [Equiv, union]
24372441
intro equiv
24382442
apply @List.insertList_perm_of_perm_second _ _ _ _ (toListModel m₂.val.buckets) (toListModel m₃.val.buckets) (toListModel m₁.val.buckets) equiv
24392443
all_goals wf_trivial
24402444

24412445
theorem union_insert_right_equiv_insert_union [EquivBEq α] [LawfulHashable α] {p : (a : α) × β a}
24422446
(h₁ : m₁.val.WF) (h₂ : m₂.val.WF) :
24432447
(m₁.union (m₂.insert p.fst p.snd)).1.Equiv ((m₁.union m₂).insert p.fst p.snd).1 := by
2444-
simp_to_model [union, insert]
2448+
simp_to_model [Equiv, union, insert]
24452449
apply List.Perm.trans
24462450
. apply insertList_perm_of_perm_second
2447-
simp_to_model [insert]
2451+
simp_to_model [Equiv, insert]
24482452
. apply insertEntry_of_perm
24492453
. wf_trivial
24502454
. apply List.Perm.refl

src/Std/Data/DHashMap/Internal/WF.lean

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,11 +405,51 @@ theorem getₘ_eq_getValue [BEq α] [Hashable α] [LawfulBEq α] {m : Raw₀ α
405405
apply_bucket_with_proof hm a AssocList.getCast List.getValueCast AssocList.getCast_eq
406406
List.getValueCast_of_perm List.getValueCast_append_of_containsKey_eq_false
407407

408+
theorem getEntryₘ_eq_getEntry [BEq α] [PartialEquivBEq α] [Hashable α] [LawfulHashable α] {m : Raw₀ α β} (hm : Raw.WFImp m.1)
409+
{a : α} {h : m.containsₘ a} :
410+
m.getEntryₘ a h = List.getEntry a (toListModel m.1.buckets) (containsₘ_eq_containsKey hm ▸ h) :=
411+
apply_bucket_with_proof hm a AssocList.getEntry List.getEntry AssocList.getEntry_eq getEntry_of_perm getEntry_append_of_containsKey_eq_false
412+
413+
theorem getEntry?ₘ_eq_getEntry? [BEq α] [PartialEquivBEq α] [Hashable α] [LawfulHashable α] {m : Raw₀ α β} (hm : Raw.WFImp m.1)
414+
{a : α} :
415+
m.getEntry?ₘ a = List.getEntry? a (toListModel m.1.buckets) :=
416+
apply_bucket hm AssocList.getEntry?_eq getEntry?_of_perm getEntry?_append_of_containsKey_eq_false
417+
408418
theorem get_eq_getValueCast [BEq α] [Hashable α] [LawfulBEq α] {m : Raw₀ α β} (hm : Raw.WFImp m.1)
409419
{a : α} {h : m.contains a} :
410420
m.get a h = getValueCast a (toListModel m.1.buckets) (contains_eq_containsKey hm ▸ h) := by
411421
rw [get_eq_getₘ, getₘ_eq_getValue hm]
412422

423+
theorem getEntry_eq_getEntry [BEq α] [Hashable α] [PartialEquivBEq α] [LawfulHashable α] {m : Raw₀ α β} (hm : Raw.WFImp m.1)
424+
{a : α} {h : m.contains a} :
425+
m.getEntry a h = List.getEntry a (toListModel m.1.buckets) (contains_eq_containsKey hm ▸ h) := by
426+
rw [getEntry_eq_getEntryₘ, getEntryₘ_eq_getEntry hm]
427+
428+
theorem getEntry?_eq_getEntry? [BEq α] [Hashable α] [PartialEquivBEq α] [LawfulHashable α] {m : Raw₀ α β} (hm : Raw.WFImp m.1)
429+
{a : α} :
430+
m.getEntry? a = List.getEntry? a (toListModel m.1.buckets) := by
431+
rw [getEntry?_eq_getEntry?ₘ, getEntry?ₘ_eq_getEntry? hm]
432+
433+
theorem getEntryDₘ_eq_getEntryD [BEq α] [PartialEquivBEq α] [Hashable α] [LawfulHashable α] {m : Raw₀ α β} (hm : Raw.WFImp m.1)
434+
{a : α} {fallback : (a : α) × β a} :
435+
m.getEntryDₘ a fallback = List.getEntryD a fallback (toListModel m.1.buckets) :=
436+
apply_bucket hm AssocList.getEntryD_eq getEntryD_of_perm getEntryD_append_of_containsKey_eq_false
437+
438+
theorem getEntryD_eq_getEntryD [BEq α] [Hashable α] [PartialEquivBEq α] [LawfulHashable α] {m : Raw₀ α β} (hm : Raw.WFImp m.1)
439+
{a : α} {fallback : (a : α) × β a} :
440+
m.getEntryD a fallback = List.getEntryD a fallback (toListModel m.1.buckets) := by
441+
rw [getEntryD_eq_getEntryDₘ, getEntryDₘ_eq_getEntryD hm]
442+
443+
theorem getEntry!ₘ_eq_getEntry! [BEq α] [PartialEquivBEq α] [Hashable α] [LawfulHashable α] {m : Raw₀ α β} (hm : Raw.WFImp m.1)
444+
{a : α} [Inhabited ((a : α) × β a)] :
445+
m.getEntry!ₘ a = List.getEntry! a (toListModel m.1.buckets) :=
446+
apply_bucket hm AssocList.getEntry!_eq getEntry!_of_perm getEntry!_append_of_containsKey_eq_false
447+
448+
theorem getEntry!_eq_getEntry! [BEq α] [Hashable α] [PartialEquivBEq α] [LawfulHashable α] {m : Raw₀ α β} (hm : Raw.WFImp m.1)
449+
{a : α} [Inhabited ((a : α) × β a)] :
450+
m.getEntry! a = List.getEntry! a (toListModel m.1.buckets) := by
451+
rw [getEntry!_eq_getEntry!ₘ, getEntry!ₘ_eq_getEntry! hm]
452+
413453
theorem get!ₘ_eq_getValueCast! [BEq α] [Hashable α] [LawfulBEq α] {m : Raw₀ α β}
414454
(hm : Raw.WFImp m.1) {a : α} [Inhabited (β a)] :
415455
m.get!ₘ a = getValueCast! a (toListModel m.1.buckets) := by

src/Std/Data/DHashMap/Raw.lean

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,41 @@ If no panic occurs the result is guaranteed to be pointer equal to the key in th
300300
Raw₀.getKey! ⟨m, h⟩ a
301301
else default -- will never happen for well-formed inputs
302302

303+
/--
304+
Checks if a mapping for the given key exists and returns the key-value pair if it does, otherwise `none`.
305+
The key in the returned pair will be `BEq` to the input `a`.
306+
-/
307+
@[inline] def getEntry? [BEq α] [Hashable α] (m : Raw α β) (a : α) : Option ((a : α) × β a) :=
308+
if h : 0 < m.buckets.size then
309+
Raw₀.getEntry? ⟨m, h⟩ a
310+
else none -- will never happen for well-formed inputs
311+
312+
/--
313+
Retrieves the key-value pair, whose key matches `a`. Ensures that such a mapping exists by
314+
requiring a proof of `a ∈ m`. The key in the returned pair will be `BEq` to the input `a`.
315+
-/
316+
@[inline] def getEntry [BEq α] [Hashable α] (m : Raw α β) (a : α) (h : a ∈ m) : (a : α) × β a :=
317+
Raw₀.getEntry ⟨m, by change dite .. = true at h; split at h <;> simp_all⟩ a
318+
(by change dite .. = true at h; split at h <;> simp_all)
319+
320+
/--
321+
Checks if a mapping for the given key exists and returns the key-value pair if it does, otherwise `fallback`.
322+
The key in the returned pair will be `BEq` to the input `a`.
323+
-/
324+
@[inline] def getEntryD [BEq α] [Hashable α] (m : Raw α β) (a : α) (fallback : (a : α) × β a) : (a : α) × β a :=
325+
if h : 0 < m.buckets.size then
326+
Raw₀.getEntryD ⟨m, h⟩ a fallback
327+
else fallback -- will never happen for well-formed inputs
328+
329+
/--
330+
Checks if a mapping for the given key exists and returns the key-value pair if it does, otherwise panics.
331+
The key in the returned pair will be `BEq` to the input `a`.
332+
-/
333+
@[inline] def getEntry! [BEq α] [Hashable α] [Inhabited ((a : α) × β a)] (m : Raw α β) (a : α) : (a : α) × β a :=
334+
if h : 0 < m.buckets.size then
335+
Raw₀.getEntry! ⟨m, h⟩ a
336+
else default -- will never happen for well-formed inputs
337+
303338
/--
304339
Returns `true` if the hash map contains no mappings.
305340

0 commit comments

Comments
 (0)