Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/Std/Data/DHashMap/Raw.lean
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,15 @@ section Unverified

/-! We currently do not provide lemmas for the functions below, with the exception of `partition`. -/

/--
Partitions a hash map into two hash maps based on a predicate.

The first component of the result receives every entry `(a, b)` of `m` for which `f a b` is `true`;
the second component receives every other entry. Both components are built by folding over `m`
and inserting into two initially empty maps.
-/
@[inline] def partition [BEq α] [Hashable α] (f : (a : α) → β a → Bool)
(m : Raw α β) : Raw α β × Raw α β :=
-- `l` accumulates the entries accepted by `f`, `r` the rejected ones.
m.fold (init := (∅, ∅)) fun ⟨l, r⟩ a b =>
if f a b then
(l.insert a b, r)
else
(l, r.insert a b)

/-- Collects every value stored in the hash map into a list, in some order. -/
@[inline] def values {β : Type v} (m : Raw α (fun _ => β)) : List β :=
  Internal.foldRev (fun collected _ value => value :: collected) [] m
Expand Down
322 changes: 322 additions & 0 deletions src/Std/Data/DHashMap/RawLemmas.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5545,6 +5545,328 @@ end Const

end map

/--
For a well-formed map `m`, the sizes of the two components of `m.partition p` add up to the size
of `m`.
-/
theorem size_partition [EquivBEq α] [LawfulHashable α] {p : (a : α) → β a → Bool} (h : m.WF) :
(m.partition p).1.size + (m.partition p).2.size = m.size := by
-- Unfold `partition` and restate its fold as a `List.foldl` over `m.toList`.
simp only [partition, fold_eq_foldl_toList h]
-- `f pair l` is that same left fold, but started from an arbitrary pair of accumulator maps.
let f : Raw α β × Raw α β → List ((a : α) × β a) → Raw α β × Raw α β :=
fun pair l => List.foldl (fun a b => if p b.1 b.2 = true
then (a.1.insert b.1 b.2, a.2) else (a.1, a.2.insert b.1 b.2)) pair l
-- Generalised claim, proved by induction on `l`: if no key of `l` occurs in `m₁` or `m₂` (`h₃`)
-- and the keys of `l` are pairwise distinct (`h₄`), then folding `l` into `(m₁, m₂)` yields a
-- combined size of `l.length + m₁.size + m₂.size`.
suffices ∀ (l : List ((a : α) × β a)) (m₁ m₂ : Raw α β) (h₁ : m₁.WF) (h₂ : m₂.WF)
(h₃ : ∀ x ∈ l, ¬ x.1 ∈ m₁ ∧ ¬ x.1 ∈ m₂)
(h₄ : l.Pairwise (fun a b => (a.1 == b.1) = false)),
(f (m₁, m₂) l).1.size + (f (m₁, m₂) l).2.size = l.length + m₁.size + m₂.size by
-- Instantiate the general claim with `l := m.toList` and two empty accumulators.
have := this m.toList ∅ ∅ WF.empty WF.empty (by simp) (distinct_keys_toList h)
simp only [length_toList h, size_empty, Nat.add_zero, f] at this
apply this
intro l
induction l with
| nil => simp [f]
| cons hd tl ih =>
intro m₁ m₂ h₁ h₂ h₃ h₄
simp only [List.foldl_cons, List.length_cons, f]
by_cases h : p hd.1 hd.2 = true
-- Case `p hd = true`: `hd` is inserted into the first accumulator.
· simp only [h, ↓reduceIte]
rw [ih _ _ (WF.insert h₁) h₂]
-- `h₃` says the key of `hd` is not yet in `m₁`, which resolves the `if` left by `size_insert`.
· rw [size_insert h₁]
simp only [(h₃ hd (by simp)), ↓reduceIte, Nat.add_right_cancel_iff]
omega
-- Re-establish `h₃` for the tail against the enlarged first accumulator.
· intro x hx
simp only [List.pairwise_cons] at h₄
simp [h₁, h₄.1 x hx, h₃ x (by simp [hx])]
-- The tail inherits pairwise key distinctness.
· apply List.Pairwise.of_cons h₄
-- Case `p hd = false`: symmetric, `hd` is inserted into the second accumulator.
· simp only [h, Bool.false_eq_true, ↓reduceIte]
rw [ih _ _ h₁ (WF.insert h₂)]
· rw [size_insert h₂]
simp only [(h₃ hd (by simp)), ↓reduceIte]
omega
· intro x hx
simp only [List.pairwise_cons] at h₄
simp [h₂, h₄.1 x hx, h₃ x (by simp [hx])]
· apply List.Pairwise.of_cons h₄

/--
Partitioning by the negated predicate and taking the first component gives the same map as
partitioning by `p` and taking the second component.
-/
theorem partition_not_fst_eq_partition_snd [EquivBEq α] [LawfulHashable α]
{p : (a : α) → β a → Bool} (h : m.WF) :
(m.partition (fun a b => ! p a b)).fst = (m.partition p).snd := by
-- Unfold both partitions into left folds over `m.toList`; the negation `! p … = true` is
-- normalised to `p … = false`.
simp only [partition, Bool.not_eq_eq_eq_not, Bool.not_true, fold_eq_foldl_toList h]
-- `f` is the fold that tests `p … = true` (the partition by `p`).
let f : Raw α β × Raw α β → List ((a : α) × β a) → Raw α β × Raw α β :=
fun pair l => List.foldl (fun a b => if p b.1 b.2 = true
then (a.1.insert b.1 b.2, a.2) else (a.1, a.2.insert b.1 b.2)) pair l
-- `f'` is the fold that tests `p … = false` (the partition by the negated predicate).
let f' : Raw α β × Raw α β → List ((a : α) × β a) → Raw α β × Raw α β :=
fun pair l => List.foldl (fun a b => if p b.1 b.2 = false
then (a.1.insert b.1 b.2, a.2) else (a.1, a.2.insert b.1 b.2)) pair l
-- Generalised claim: running `f'` from `(m₁, m₂)` and `f` from the swapped pair `(m₂, m₁)`
-- produces the same map in the first resp. second component.
suffices ∀ (l : List ((a : α) × β a)) (m₁ m₂ : Raw α β) (h₁ : m₁.WF) (h₂ : m₂.WF),
(f' (m₁, m₂) l).fst = (f (m₂, m₁) l).snd from this _ _ _ WF.empty WF.empty
intro l
induction l with
| nil => simp [f, f']
| cons hd tl ih =>
intro m₁ m₂ h₁ h₂
simp only [List.foldl_cons, f', f]
by_cases hhd : p hd.fst hd.snd = true
-- `p hd = true`: both folds insert `hd` into `m₂`.
· simp only [hhd, Bool.true_eq_false, ↓reduceIte]
rw [ih _ _ h₁ (WF.insert h₂)]
-- `p hd = false`: both folds insert `hd` into `m₁`.
· simp only [hhd, ↓reduceIte, Bool.false_eq_true]
rw [ih _ _ (WF.insert h₁) h₂]

/-- The first component of `m.partition p` is well-formed whenever `m` is. -/
theorem wf_partition_fst [EquivBEq α] [LawfulHashable α] (h : m.WF) {p : (a : α) → β a → Bool} :
(m.partition p).1.WF := by
-- Unfold `partition` and restate its fold as a `List.foldl` over `m.toList`.
simp only [partition, fold_eq_foldl_toList h]
-- `f pair l` is that left fold started from an arbitrary pair of accumulator maps.
let f : Raw α β × Raw α β → List ((a : α) × β a) → Raw α β × Raw α β :=
fun pair l => List.foldl (fun a b => if p b.1 b.2 = true
then (a.1.insert b.1 b.2, a.2) else (a.1, a.2.insert b.1 b.2)) pair l
-- Generalised claim: starting from two well-formed accumulators, the first component of the
-- fold result is well-formed. The theorem is the instance with two empty accumulators.
suffices ∀ (l : List ((a : α) × β a)) (m₁ m₂ : Raw α β) (h₁ : m₁.WF) (h₂ : m₂.WF),
(f (m₁, m₂) l).1.WF from this _ _ _ WF.empty WF.empty
intro l
induction l with
| nil =>
-- Nothing is folded, so the result is the initial first accumulator.
simp only [List.foldl_nil, f]
intro m₁ _ h₁ _
apply h₁
| cons hd tl ih =>
intro m₁ m₂ h₁ h₂
simp only [List.foldl_cons, f]
by_cases hhd : p hd.fst hd.snd = true
-- `hd` goes into the first accumulator; `WF.insert` keeps it well-formed.
· simp only [hhd, ↓reduceIte]
apply ih _ _ (WF.insert h₁) h₂
-- `hd` goes into the second accumulator; `WF.insert` keeps it well-formed.
· simp only [hhd, Bool.false_eq_true, ↓reduceIte]
apply ih _ _ h₁ (WF.insert h₂)

section Perm

-- NOTE(review): this `variable` looks redundant, because the theorem below binds its own
-- `{α : Type u}` — confirm before removing.
variable {α : Type u}

/--
Two duplicate-free lists with the same elements are permutations of each other.
-/
theorem List.Perm.of_nodup_of_nodup_of_forall_mem_iff_mem {α : Type u} (l₁ l₂ : List α)
(h₁ : l₁.Nodup) (h₂ : l₂.Nodup) (h₃ : ∀ (a : α), a ∈ l₁ ↔ a ∈ l₂) :
l₁.Perm l₂ := by
induction l₁ generalizing l₂ with
| nil =>
-- `l₁ = []`: `l₂` must be empty too, otherwise its head would have to be a member of `[]`.
cases l₂ with
| nil => simp
| cons hd tl =>
specialize h₃ hd
simp at h₃
| cons hd tl ih =>
-- `hd` occurs in `l₂`, so split `l₂` around that occurrence as `pre ++ hd :: post`.
have hd_mem : hd ∈ l₂ := (h₃ hd).mp (by simp)
rw [List.mem_iff_append] at hd_mem
rcases hd_mem with ⟨pre, post, h⟩
rw [h]
-- Move `hd` to the front of the right-hand side and cancel it; what remains is
-- `tl.Perm (pre ++ post)`, which is handled by the induction hypothesis.
apply List.Perm.trans ?_ List.perm_middle.symm
apply List.Perm.cons
-- `pre ++ post` is duplicate-free because `hd :: (pre ++ post)` is a permutation of `l₂`.
have nodup_pre_post : (pre ++ post).Nodup := by
rw [h] at h₂
have := List.Perm.nodup List.perm_middle h₂
simp only [List.nodup_cons, List.mem_append, not_or] at this
apply this.2
apply ih
-- First premise of `ih`: `tl` is duplicate-free.
· simp only [List.nodup_cons] at h₁
apply h₁.2
-- Second premise of `ih`: `pre ++ post` is duplicate-free.
· apply nodup_pre_post
-- Third premise of `ih`: `tl` and `pre ++ post` have the same members.
· simp only [List.mem_append]
intro x
constructor
-- Forward direction: `x ∈ tl` implies `x ∈ pre ∨ x ∈ post`.
· intro hx
specialize h₃ x
simp only [List.mem_cons, hx, or_true, h, List.mem_append, true_iff] at h₃
simp only [h] at h₂
cases h₃ with
| inl h => apply Or.inl h
| inr h =>
cases h with
| inl h =>
-- `x = hd` would put `hd` into `tl`, contradicting `(hd :: tl).Nodup`.
rw [h] at hx
simp only [List.nodup_cons] at h₁
simp [hx] at h₁
| inr h =>
apply Or.inr h
-- Backward direction: `x ∈ pre ∨ x ∈ post` implies `x ∈ tl`.
· intro h'
simp only [h, List.nodup_append, List.nodup_cons, List.mem_cons, ne_eq,
forall_eq_or_imp] at h₂
rw [h] at h₃
specialize h₃ x
simp only [List.mem_cons, List.mem_append] at h₃
cases h' with
| inl h' =>
simp only [h', true_or, iff_true] at h₃
cases h₃ with
| inl h₃ =>
-- `x = hd` with `x ∈ pre` contradicts the duplicate-freeness of `pre ++ hd :: post`.
have := (h₂.2.2 x h').1
contradiction
| inr h₃ =>
apply h₃
| inr h' =>
simp only [h', or_true, iff_true] at h₃
cases h₃ with
| inl h₃ =>
-- `x = hd` with `x ∈ post` contradicts the duplicate-freeness of `hd :: post`.
have := (h₂.2.1.1)
simp [← h₃, h'] at this
| inr h₃ =>
apply h₃

end Perm

/--
If `k` is not yet a key of the well-formed map `m`, then the entries of `m.insert k v` are exactly
the entries of `m` together with the new entry `⟨k, v⟩`.
-/
theorem mem_toList_insert [EquivBEq α] [LawfulHashable α] (h : m.WF) {k : α} {v : β k}
{x : (a : α) × β a} (h' : ¬ k ∈ m) :
x ∈ (m.insert k v).toList ↔ x ∈ m.toList ∨ x = ⟨k, v⟩ := by
-- Restate the hypothesis `¬ k ∈ m` in terms of `containsₘ`.
have : containsₘ ⟨m, h.size_buckets_pos⟩ k = false := by
simp only [← contains_eq_containsₘ]
rw [mem_iff_contains] at h'
simp only [contains, h.size_buckets_pos, ↓reduceDIte, Bool.not_eq_true] at h'
apply h'
-- Unfold `toList` and `insert` down to `insertₘ`; `this` selects its `if` branch.
simp only [toList, insert, h.size_buckets_pos, ↓reduceDIte, insert_eq_insertₘ, insertₘ, this,
Bool.false_eq_true, ↓reduceIte, foldRev_cons, List.append_nil]
-- Membership is invariant under permutation, so the two `toListModel` permutation lemmas
-- transfer the membership statement step by step.
rw [List.Perm.mem_iff (toListModel_expandIfNecessary (consₘ ⟨m, h.size_buckets_pos⟩ k v))]
rw [List.Perm.mem_iff (toListModel_consₘ ⟨m, h.size_buckets_pos⟩ (Raw.WF.out h) k v)]
simp [Or.comm]

/--
An entry belongs to the first component of `m.partition p` exactly when it is an entry of `m`
that satisfies `p`.
-/
theorem mem_toList_fst_partition [EquivBEq α] [LawfulHashable α] (h : m.WF)
{p : (a : α) → β a → Bool} (x : (a : α) × β a) :
x ∈ (m.partition p).1.toList ↔ x ∈ m.toList ∧ p x.1 x.2 = true := by
-- Unfold `partition` and restate its fold as a `List.foldl` over `m.toList`.
simp only [partition, fold_eq_foldl_toList h]
-- `f pair l` is that left fold started from an arbitrary pair of accumulator maps.
let f : Raw α β × Raw α β → List ((a : α) × β a) → Raw α β × Raw α β :=
fun pair l => List.foldl (fun a b => if p b.1 b.2 = true
then (a.1.insert b.1 b.2, a.2) else (a.1, a.2.insert b.1 b.2)) pair l
-- Generalised claim, proved by induction on `l`: if no key of `l` occurs in `m₁` (`h₃`) and the
-- keys of `l` are pairwise distinct (`h₄`), then the first component of the fold contains
-- exactly the entries of `l` satisfying `p` plus the entries already in `m₁`.
suffices ∀ (l : List ((a : α) × β a)) (m₁ m₂ : Raw α β) (h₁ : m₁.WF) (h₂ : m₂.WF)
(h₃ : ∀ x ∈ l, ¬ x.1 ∈ m₁)
(h₄ : l.Pairwise (fun a b => (a.1 == b.1) = false)),
x ∈ (f (m₁, m₂) l).1.toList ↔ (x ∈ l ∧ p x.1 x.2 = true) ∨ x ∈ m₁.toList by
-- Instantiate with `l := m.toList` and two empty accumulators.
specialize this m.toList ∅ ∅ WF.empty WF.empty (by simp) (distinct_keys_toList h)
rw [this]
simp
intro l
induction l with
| nil =>
simp [f]
| cons hd tl ih =>
intro m₁ m₂ h₁ h₂ h₃ h₄
simp only [List.foldl_cons, List.mem_cons, f]
by_cases hhd : p hd.1 hd.2
-- Case `p hd = true`: `hd` is inserted into the first accumulator.
· simp only [hhd, ↓reduceIte]
rw [ih _ _ (WF.insert h₁) h₂]
-- Describe the entries of `m₁.insert hd.1 hd.2` via `mem_toList_insert`, then shuffle the
-- resulting disjunctions into the shape of the goal.
· rw [mem_toList_insert h₁ (h₃ hd (by simp))]
constructor
· intro h
cases h with
| inl h =>
apply Or.inl (And.intro (Or.inr h.1) h.2)
| inr h =>
cases h with
| inl h =>
apply Or.inr h
| inr h =>
rw [h]
simp [hhd]
· intro h
cases h with
| inl h =>
rcases h with ⟨h, h'⟩
cases h with
| inl h =>
rw [h]
simp [hhd]
| inr h =>
simp [h, h']
| inr h =>
simp [h]
-- Re-establish `h₃` for the tail: its keys differ from `hd.1` (by `h₄`) and are not in `m₁`.
· intro x hx
simp only [mem_insert h₁, not_or, Bool.not_eq_true]
constructor
· simp only [List.pairwise_cons] at h₄
apply h₄.1 x hx
· apply h₃ x (by simp [hx])
-- The tail inherits pairwise key distinctness.
· simp only [List.pairwise_cons] at h₄
apply h₄.2
-- Case `p hd = false`: `hd` goes to the second accumulator, the first one is unchanged.
· simp only [hhd, Bool.false_eq_true, ↓reduceIte]
rw [ih _ _ h₁ (WF.insert h₂)]
· constructor
· intro h
cases h with
| inl h =>
apply Or.inl (And.intro (Or.inr h.1) h.2)
| inr h =>
apply Or.inr h
· intro h
cases h with
| inl h =>
rcases h with ⟨h, h'⟩
cases h with
| inl h =>
-- `x = hd` together with `p x = true` contradicts `hhd`.
rw [h] at h'
contradiction
| inr h =>
simp [h, h']
| inr h =>
simp [h]
· simp only [List.mem_cons, forall_eq_or_imp] at h₃
apply h₃.2
· simp only [List.pairwise_cons] at h₄
apply h₄.2

/-- The entry list of a well-formed map contains no duplicates. -/
theorem nodup_toList [EquivBEq α] [LawfulHashable α] (h : m.WF) : m.toList.Nodup := by
-- `Nodup` unfolds to `Pairwise (· ≠ ·)`.
simp only [List.Nodup, ne_eq]
-- It suffices that pairwise distinct keys (which `distinct_keys_toList` provides for `m.toList`)
-- imply pairwise distinct entries.
suffices ∀ (l : List ((a :α) × β a)) (h₁ : List.Pairwise (fun a b => (a.1 == b.1) = false) l),
List.Pairwise (fun a b => a ≠ b) l from this m.toList (distinct_keys_toList h)
intro l
induction l with
| nil => simp
| cons hd tl ih =>
simp
intro h₁ h₂
refine And.intro ?_ (ih h₂)
intro a ha
-- Argue by contradiction: an equality between `hd` and `a` (named `h'`) is incompatible with
-- `h₁ a ha`, which says their keys compare as different.
false_or_by_contra
rename_i h'
specialize h₁ a ha
simp [h'] at h₁

/--
Filtering a duplicate-free list yields a duplicate-free list.

NOTE(review): the core library may already provide this fact (e.g. via a `Pairwise`/`Sublist`
lemma for `filter`) — confirm whether this local copy is needed.
-/
theorem List.nodup_filter_of_nodup {α : Type u} {l : List α} (h : l.Nodup) {p : α → Bool} :
(l.filter p).Nodup := by
induction l with
| nil => simp
| cons hd tl ih =>
simp only [List.nodup_cons, List.filter] at h ⊢
-- Case split on whether `p hd` keeps or drops the head; `simp` closes both cases using the
-- hypothesis `h` and the induction hypothesis.
split <;> simp [h, ih]

/--
For a well-formed map, the first component of `m.partition p` is equivalent (`~m`) to
`m.filter p`.
-/
theorem partition_fst_equiv_filter [EquivBEq α] [LawfulHashable α]
{p : (a : α) → β a → Bool} (h : m.WF) :
(m.partition p).fst ~m m.filter p := by
-- Equivalence follows from the entry lists being permutations of each other.
apply Equiv.of_toList_perm
-- Replace `(m.filter p).toList` by the list that `toList_filter` relates it to.
apply List.Perm.trans _ (toList_filter h).symm
-- Both lists are duplicate-free and have the same members, hence they are permutations.
apply List.Perm.of_nodup_of_nodup_of_forall_mem_iff_mem
· apply nodup_toList
apply wf_partition_fst h
· apply List.nodup_filter_of_nodup (nodup_toList h)
· intro a
simp [mem_toList_fst_partition h]

/--
`getKey?` on the first component of `m.partition p` is `getKey?` on `m`, filtered by whether the
key found there, together with its value in `m`, satisfies `p`.
-/
theorem getKey?_partition_fst [LawfulBEq α]
{p : (a : α) → β a → Bool} (h : m.WF) {k : α} :
(m.partition p).1.getKey? k = (m.getKey? k).pfilter (fun x h' =>
p x (m.get x (mem_of_getKey?_eq_some h h'))) := by
-- The right-hand side is the `getKey?` of `m.filter p` (by `getKey?_filter`), and equivalent
-- maps agree on `getKey?`.
rw [← getKey?_filter h]
apply Equiv.getKey?_eq (wf_partition_fst h) (WF.filter h)
apply partition_fst_equiv_filter h

/--
`getKey?` on the second component of `m.partition p` is `getKey?` on `m`, filtered by whether the
key found there, together with its value in `m`, fails `p`.
-/
theorem getKey?_partition_snd [LawfulBEq α]
{p : (a : α) → β a → Bool} (h : m.WF) {k : α} :
(m.partition p).2.getKey? k = (m.getKey? k).pfilter (fun x h' =>
! p x (m.get x (mem_of_getKey?_eq_some h h'))) := by
-- Rewrite the second component as the first component of the partition by the negated
-- predicate, then argue as for the first component with `fun a b => ! p a b` in place of `p`.
rw [← partition_not_fst_eq_partition_snd h, ← getKey?_filter h (f := fun a b => ! p a b)]
apply Equiv.getKey?_eq (wf_partition_fst h) (WF.filter h)
apply partition_fst_equiv_filter h

/-- The first component of `m.partition p` has at most as many entries as `m`. -/
theorem size_partition_fst_le_size [EquivBEq α] [LawfulHashable α]
{p : (a : α) → β a → Bool} (h : m.WF) :
(m.partition p).1.size ≤ m.size := by
-- Equivalent maps have equal size, so the bound for `m.filter p` carries over.
rw [Equiv.size_eq (wf_partition_fst h) (WF.filter h) (partition_fst_equiv_filter h)]
apply size_filter_le_size h

/-- The second component of `m.partition p` has at most as many entries as `m`. -/
theorem size_partition_snd_le_size [EquivBEq α] [LawfulHashable α]
{p : (a : α) → β a → Bool} (h : m.WF) :
(m.partition p).2.size ≤ m.size := by
-- Rewrite the second component as the first component of the partition by the negated predicate.
rw [← partition_not_fst_eq_partition_snd h]
-- Equivalent maps have equal size, so the bound for the corresponding `filter` carries over.
rw [Equiv.size_eq (wf_partition_fst h) (WF.filter h) (partition_fst_equiv_filter h)]
apply size_filter_le_size h

-- Tag `contains_eq_false_iff_not_mem` as a `simp` lemma from this point on.
-- NOTE(review): presumably deferred to the end of the file so that the proofs above run with the
-- simp set they were written against — confirm.
attribute [simp] contains_eq_false_iff_not_mem
end Raw
end Std.DHashMap
44 changes: 44 additions & 0 deletions tests/bench/hashmap.lean
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,50 @@ def compareAnyBench : IO Unit := do

end anyTests

section partitionTests

/--
Baseline used by the partition benchmark: splits `m` with two separate `filter` calls, one keeping
the elements that satisfy `p` and one keeping the elements that do not.
-/
def Std.HashSet.partition' {α : Type} [BEq α] [Hashable α] (m : Std.HashSet α) (p : α → Bool) :=
  let accepted := m.filter p
  let rejected := m.filter (fun a => ! p a)
  (accepted, rejected)

/--
Builds the input shared by the partition benchmarks: a hash set containing exactly the natural
numbers `0, 1, …, size - 1`.
-/
private def mkPartitionBenchSet (size : Nat) : Std.HashSet Nat := Id.run do
  let mut set := Std.HashSet.emptyWithCapacity (α := Nat) size
  for i in [0:size] do
    set := set.insert i
  return set

/--
Times the built-in `partition` on a set of `size` elements, split by divisibility by `p`.

The set is populated before the timed region starts; previously it was created with
`emptyWithCapacity` and never filled, so the benchmark only ever partitioned an empty set and
neither `size` nor `p` influenced the measured work.

Throws a user error if the two parts do not add up to the size of the input set.
-/
def benchNativePartition (size : Nat) (p : Nat) : IO Float := do
  let set := mkPartitionBenchSet size
  -- NOTE(review): `timeNanos` is defined elsewhere in this file; presumably it normalises the
  -- elapsed time by `checks` — confirm.
  let checks := size
  timeNanos checks do
    let (l, r) := set.partition (fun x => x % p == 0)
    -- Using both results keeps the partition from being optimised away and doubles as a sanity check.
    if l.size + r.size != set.size then
      throw <| .userError "Fail"

/--
Times the filter-based baseline `partition'` on a set of `size` elements, split by divisibility
by `p`. Uses the same populated input as `benchNativePartition` so the two timings are comparable.

Throws a user error if the two parts do not add up to the size of the input set.
-/
def benchFilterPartition (size : Nat) (p : Nat) : IO Float := do
  let set := mkPartitionBenchSet size
  -- NOTE(review): see `timeNanos` (defined elsewhere) for how `checks` is used — confirm.
  let checks := size
  timeNanos checks do
    let (l, r) := set.partition' (fun x => x % p == 0)
    if l.size + r.size != set.size then
      throw <| .userError "Fail"

/--
Runs the native and the filter-based partition benchmark for every combination of the listed sizes
and every prime in `testPrimes`, prints both timings for each combination, and finally prints how
often each variant was at least as fast as (native) or strictly faster than (filter) the other.
-/
def evalPartition := do
  let mut nativeWins := 0
  let mut filterWins := 0

  for size in [100, 1000, 10000, 100000, 1000000] do
    for p in testPrimes do
      let nativeTime ← benchNativePartition size p
      let filterTime ← benchFilterPartition size p

      IO.println s!"Native scenario size: {size} prime : {p} time {nativeTime}"
      IO.println s!"Filter scenario size: {size} prime : {p} time {filterTime}"

      -- Ties count in favour of the native implementation.
      if nativeTime ≤ filterTime then
        nativeWins := nativeWins + 1
      else
        filterWins := filterWins + 1

  IO.println s!"Native function better: {nativeWins}"
  IO.println s!"Filter function better: {filterWins}"

-- Runs the native-vs-filter partition comparison whenever this file is elaborated.
-- NOTE(review): this executes the whole comparison at compile time; consider calling
-- `evalPartition` from `main` instead — confirm intent.
#eval evalPartition

end partitionTests

def main (args : List String) : IO Unit := do
let seed := args[0]!.toNat!.toUInt64
let size := args[1]!.toNat!
Expand Down
Loading