Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/Std/Data/DHashMap/Raw.lean
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,15 @@ section Unverified

/-! We currently do not provide lemmas for the functions below. -/

/-- Partition a hash map into two hash map based on a predicate. -/
@[inline] def partition [BEq α] [Hashable α] (f : (a : α) → β a → Bool)
(m : Raw α β) : Raw α β × Raw α β :=
m.fold (init := (∅, ∅)) fun ⟨l, r⟩ a b =>
if f a b then
(l.insert a b, r)
else
(l, r.insert a b)

/-- Returns a list of all values present in the hash map in some order. -/
@[inline] def values {β : Type v} (m : Raw α (fun _ => β)) : List β :=
Internal.foldRev (fun acc _ v => v :: acc) [] m
Expand Down
232 changes: 232 additions & 0 deletions src/Std/Data/DHashMap/RawLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5545,6 +5545,238 @@ end Const

end map

theorem size_partition {p : (a : α) → β a → Bool} (h : m.WF) [EquivBEq α] [LawfulHashable α] :
(m.partition p).1.size + (m.partition p).2.size = m.size := by
simp [partition, fold_eq_foldl_toList h]
let f : Raw α β × Raw α β → List ((a : α) × β a) → Raw α β × Raw α β :=
fun pair l => List.foldl (fun a b => if p b.1 b.2 = true
then (a.1.insert b.1 b.2, a.2) else (a.1, a.2.insert b.1 b.2)) pair l
suffices ∀ (l : List ((a : α) × β a)) (m₁ m₂ : Raw α β) (h₁ : m₁.WF) (h₂ : m₂.WF)
(h₃ : ∀ x ∈ l, ¬ x.1 ∈ m₁ ∧ ¬ x.1 ∈ m₂)
(h₄ : l.Pairwise (fun a b => (a.1 == b.1) = false)), (f (m₁, m₂) l).1.size + (f (m₁, m₂) l).2.size = l.length + m₁.size + m₂.size by
have := this m.toList ∅ ∅ WF.empty WF.empty (by simp) (distinct_keys_toList h)
simp [f, length_toList h] at this
apply this
intro l
induction l with
| nil => simp [f]
| cons hd tl ih =>
intro m₁ m₂ h₁ h₂ h₃ h₄
simp [f]
by_cases h : p hd.1 hd.2 = true
· simp [h]
rw [ih _ _ (WF.insert h₁) h₂]
· rw [size_insert h₁]
simp [(h₃ hd (by simp))]
omega
· intro x hx
simp at h₄
simp [h₁, h₄.1 x hx, h₃ x (by simp [hx])]
· apply List.Pairwise.of_cons h₄
· simp [h]
rw [ih _ _ h₁ (WF.insert h₂)]
· rw [size_insert h₂]
simp [(h₃ hd (by simp))]
omega
· intro x hx
simp at h₄
simp [h₂, h₄.1 x hx, h₃ x (by simp [hx])]
· apply List.Pairwise.of_cons h₄

theorem partition_not_fst_eq_partition_snd {p : (a : α) → β a → Bool} (h : m.WF) [EquivBEq α] [LawfulHashable α] :
(m.partition (fun a b => ! p a b)).fst = (m.partition p).snd := by
simp [partition, fold_eq_foldl_toList h]
let f : Raw α β × Raw α β → List ((a : α) × β a) → Raw α β × Raw α β :=
fun pair l => List.foldl (fun a b => if p b.1 b.2 = true
then (a.1.insert b.1 b.2, a.2) else (a.1, a.2.insert b.1 b.2)) pair l
let f' : Raw α β × Raw α β → List ((a : α) × β a) → Raw α β × Raw α β :=
fun pair l => List.foldl (fun a b => if p b.1 b.2 = false
then (a.1.insert b.1 b.2, a.2) else (a.1, a.2.insert b.1 b.2)) pair l
suffices ∀ (l : List ((a : α) × β a)) (m₁ m₂ : Raw α β) (h₁ : m₁.WF) (h₂ : m₂.WF),
(f' (m₁, m₂) l).fst = (f (m₂, m₁) l).snd from this _ _ _ WF.empty WF.empty
intro l
induction l with
| nil => simp [f, f']
| cons hd tl ih =>
intro m₁ m₂ h₁ h₂
simp [f, f']
by_cases hhd : p hd.fst hd.snd = true
· simp [hhd]
rw [ih _ _ h₁ (WF.insert h₂)]
· simp [hhd]
rw [ih _ _ (WF.insert h₁) h₂]

theorem wf_partition_fst [EquivBEq α] [LawfulHashable α] (h : m.WF) {p : (a : α) → β a → Bool} :
(m.partition p).1.WF := by
simp [partition, fold_eq_foldl_toList h]
let f : Raw α β × Raw α β → List ((a : α) × β a) → Raw α β × Raw α β :=
fun pair l => List.foldl (fun a b => if p b.1 b.2 = true
then (a.1.insert b.1 b.2, a.2) else (a.1, a.2.insert b.1 b.2)) pair l
suffices ∀ (l : List ((a : α) × β a)) (m₁ m₂ : Raw α β) (h₁ : m₁.WF) (h₂ : m₂.WF),
(f (m₁, m₂) l).1.WF from this _ _ _ WF.empty WF.empty
intro l
induction l with
| nil =>
simp [f]
intro m₁ _ h₁ _
apply h₁
| cons hd tl ih =>
intro m₁ m₂ h₁ h₂
simp [f]
by_cases hhd : p hd.fst hd.snd = true
· simp [hhd]
apply ih _ _ (WF.insert h₁) h₂
· simp [hhd]
apply ih _ _ h₁ (WF.insert h₂)


theorem Perm_helper {α : Type u} (l1 l2 : List α) (h₁ : l1.Nodup) (h₂ : l2.Nodup) (h₃ : ∀ (a : α), a ∈ l1 ↔ a ∈ l2) :
l1.Perm l2 := by
sorry

theorem neg_mem_toList_empty [EquivBEq α] [LawfulHashable α] {x : (a : α) × β a} : ¬ x ∈ (∅ : Raw α β).toList := by
have := isEmpty_toList (α := α) (β := β) (m := ∅) WF.empty
simp only [isEmpty_empty, List.isEmpty_iff] at this
simp [this]

theorem mem_toList_insert [EquivBEq α] [LawfulHashable α] (h : m.WF) {k : α} {v : β k} {x : (a : α) × β a} (h' : ¬ k ∈ m) :
x ∈ (m.insert k v).toList ↔ x ∈ m.toList ∨ x = ⟨k, v⟩ := by
sorry

theorem mem_toList_fst_partition [EquivBEq α] [LawfulHashable α] (h : m.WF) {p : (a : α) → β a → Bool} (x : (a : α) × β a):
x ∈ (m.partition p).1.toList ↔ x ∈ m.toList ∧ p x.1 x.2 = true := by
simp [partition, fold_eq_foldl_toList h]
let f : Raw α β × Raw α β → List ((a : α) × β a) → Raw α β × Raw α β :=
fun pair l => List.foldl (fun a b => if p b.1 b.2 = true
then (a.1.insert b.1 b.2, a.2) else (a.1, a.2.insert b.1 b.2)) pair l
suffices ∀ (l : List ((a : α) × β a)) (m₁ m₂ : Raw α β) (h₁ : m₁.WF) (h₂ : m₂.WF)
(h₃ : ∀ x ∈ l, ¬ x.1 ∈ m₁)
(h₄ : l.Pairwise (fun a b => (a.1 == b.1) = false)),
x ∈ (f (m₁, m₂) l).1.toList ↔ (x ∈ l ∧ p x.1 x.2 = true) ∨ x ∈ m₁.toList by
specialize this m.toList ∅ ∅ WF.empty WF.empty (by simp) (distinct_keys_toList h)
rw [this]
simp [neg_mem_toList_empty]
intro l
induction l with
| nil =>
simp [f]
| cons hd tl ih =>
intro m₁ m₂ h₁ h₂ h₃ h₄
simp [f]
by_cases hhd : p hd.1 hd.2
· simp [hhd]
rw [ih _ _ (WF.insert h₁) h₂]
· rw [mem_toList_insert h₁ (h₃ hd (by simp))]
constructor
· intro h
cases h with
| inl h =>
apply Or.inl (And.intro (Or.inr h.1) h.2)
| inr h =>
cases h with
| inl h =>
apply Or.inr h
| inr h =>
rw [h]
simp [hhd]
· intro h
cases h with
| inl h =>
rcases h with ⟨h, h'⟩
cases h with
| inl h =>
rw [h]
simp [hhd]
| inr h =>
simp [h, h']
| inr h =>
simp [h]
· intro x hx
simp [mem_insert h₁]
constructor
· simp at h₄
apply h₄.1 x hx
· apply h₃ x (by simp [hx])
· simp at h₄
apply h₄.2
· simp [hhd]
rw [ih _ _ h₁ (WF.insert h₂)]
· constructor
· intro h
cases h with
| inl h =>
apply Or.inl (And.intro (Or.inr h.1) h.2)
| inr h =>
apply Or.inr h
· intro h
cases h with
| inl h =>
rcases h with ⟨h, h'⟩
cases h with
| inl h =>
rw [h] at h'
contradiction
| inr h =>
simp [h, h']
| inr h =>
simp [h]
· simp at h₃
apply h₃.2
· simp at h₄
apply h₄.2

theorem nodup_toList [EquivBEq α] [LawfulHashable α] (h : m.WF) : m.toList.Nodup := by
unfold List.Nodup
have distinct_keys := distinct_keys_toList h
suffices ∀ (l : List ((a :α) × β a)) (h₁ : List.Pairwise (fun a b => (a.1 == b.1) = false) l),
List.Pairwise (fun a b => a ≠ b) l from this m.toList distinct_keys
intro l
induction l with
| nil => simp
| cons hd tl ih =>
simp
intro h₁ h₂
apply And.intro ?_ (ih h₂)
intro a ha
false_or_by_contra
rename_i h'
specialize h₁ a ha
simp [h'] at h₁

theorem List.nodup_filter_of_nodup {α : Type u} {l : List α} (h : l.Nodup) {p : α → Bool} :
(l.filter p).Nodup := by
induction l with
| nil => simp
| cons hd tl ih =>
simp only [List.nodup_cons, List.filter] at h ⊢
split <;> simp [h, ih]

theorem partition_fst_equiv_filter [EquivBEq α] [LawfulHashable α]
{p : (a : α) → β a → Bool} (h : m.WF) :
(m.partition p).fst ~m m.filter p := by
apply Equiv.of_toList_perm
apply List.Perm.trans _ (toList_filter h).symm
apply Perm_helper
· apply nodup_toList
apply wf_partition_fst h
· apply List.nodup_filter_of_nodup (nodup_toList h)
· intro a
simp [mem_toList_fst_partition h]

theorem getKey?_partition_fst [LawfulBEq α]
{p : (a : α) → β a → Bool} (h : m.WF) {k : α} :
(m.partition p).1.getKey? k = (m.getKey? k).pfilter (fun x h' =>
p x (m.get x (mem_of_getKey?_eq_some h h'))) := by
rw [← getKey?_filter h]
apply Equiv.getKey?_eq (wf_partition_fst h) (WF.filter h)
apply partition_fst_equiv_filter h

theorem size_partition_fst_le_size [EquivBEq α] [LawfulHashable α]
{p : (a : α) → β a → Bool} (h : m.WF) :
(m.partition p).1.size ≤ m.size := by
rw [Equiv.size_eq (wf_partition_fst h) (WF.filter h) (partition_fst_equiv_filter h)]
apply size_filter_le_size h

attribute [simp] contains_eq_false_iff_not_mem
end Raw
end Std.DHashMap
44 changes: 44 additions & 0 deletions tests/bench/hashmap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,50 @@ def compareAnyBench : IO Unit := do

end anyTests

section partitionTests

def Std.HashSet.partition' {α : Type} [BEq α] [Hashable α] (m : Std.HashSet α) (p : α → Bool) := (m.filter p, m.filter (fun x => ! p x))

def benchNativePartition (size : Nat) (p : Nat) : IO Float := do
let mut set := Std.HashSet.emptyWithCapacity (α := Nat) size
let checks := size
timeNanos checks do
let (l, r) := set.partition (fun x => x % p == 0)
if l.size + r.size != set.size
then throw <| .userError "Fail"

def benchFilterPartition (size : Nat) (p : Nat) : IO Float := do
let mut set := Std.HashSet.emptyWithCapacity (α := Nat) size
let checks := size
timeNanos checks do
let (l, r) := set.partition' (fun x => x % p == 0)
if l.size + r.size != set.size
then throw <| .userError "Fail"

def evalPartition := do
let mut nativeBetter := 0
let mut filterBetter := 0

for size in [100, 1000, 10000, 100000, 1000000] do
for p in testPrimes do
let time1 ← benchNativePartition size p
let time2 ← benchFilterPartition size p

IO.println s!"Native scenario size: {size} prime : {p} time {time1}"
IO.println s!"Filter scenario size: {size} prime : {p} time {time2}"

if time1 ≤ time2 then
nativeBetter := nativeBetter + 1
else
filterBetter := filterBetter + 1

IO.println s!"Native function better: {nativeBetter}"
IO.println s!"Filter function better: {filterBetter}"

#eval evalPartition

end partitionTests

def main (args : List String) : IO Unit := do
let seed := args[0]!.toNat!.toUInt64
let size := args[1]!.toNat!
Expand Down
Loading