Skip to content

Commit 838194c

Browse files
committed
feat: add Array.partitionM, Array.partitionMapM, Array.partitionMap, List.partitionMapM
1 parent c4e5f57 commit 838194c

File tree

2 files changed

+103
-10
lines changed

2 files changed

+103
-10
lines changed

src/Init/Data/Array/Basic.lean

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1692,6 +1692,41 @@ def getMax? (as : Array α) (lt : α → α → Bool) : Option α :=
16921692
else
16931693
none
16941694

1695+
/-! ### partitionM -/
1696+
1697+
/--
1698+
Returns a pair of arrays that together contain all the elements of `as`. The first array contains
1699+
those elements for which the monadic predicate `p` returns `true`, and the second contains those for
1700+
which `p` returns `false`. The array's elements are examined in order, from left to right.
1701+
1702+
This is a monadic version of `Array.partition`.
1703+
1704+
Example:
1705+
```lean
1706+
def posOrNeg (x : Int) : Except String Bool :=
1707+
if x > 0 then pure true
1708+
else if x < 0 then pure false
1709+
else throw "Zero is not positive or negative"
1710+
1711+
#eval #[-1, 2, 3].partitionM posOrNeg
1712+
-- Except.ok (#[2, 3], #[-1])
1713+
1714+
#eval #[0, 2, 3].partitionM posOrNeg
1715+
-- Except.error "Zero is not positive or negative"
1716+
```
1717+
-/
1718+
1719+
@[inline, specialize]
1720+
def partitionM {α : Type} [Monad m] (p : α → m Bool) (as : Array α) : m (Array α × Array α) := do
1721+
let mut bs := #[]
1722+
let mut cs := #[]
1723+
for a in as do
1724+
if (← p a) then
1725+
bs := bs.push a
1726+
else
1727+
cs := cs.push a
1728+
return (bs, cs)
1729+
16951730
/--
16961731
Returns a pair of arrays that together contain all the elements of `as`. The first array contains
16971732
those elements for which `p` returns `true`, and the second contains those for which `p` returns
@@ -1716,6 +1751,46 @@ def partition (p : α → Bool) (as : Array α) : Array α × Array α := Id.run
17161751
cs := cs.push a
17171752
return (bs, cs)
17181753

1754+
/-! ### partitionMapM -/
1755+
1756+
/--
1757+
Applies a monadic function that returns a disjoint union to each element of an array,
1758+
collecting the `Sum.inl` and `Sum.inr` results into separate arrays.
1759+
1760+
Example:
1761+
```lean
1762+
def f (x : Int) : Except String (Int ⊕ String) :=
1763+
if x % 2 = 0 then pure (Sum.inl x)
1764+
else pure (Sum.inr (toString x))
1765+
1766+
#eval #[0, 1, 2, 3].partitionMapM f
1767+
-- Except.ok (#[0, 2], #["1", "3"])
1768+
```
1769+
-/
1770+
@[inline, specialize]
1771+
def partitionMapM [Monad m] (f : α → m (β ⊕ γ)) (as : Array α) : m (Array β × Array γ) := do
1772+
let mut bs := #[]
1773+
let mut cs := #[]
1774+
for a in as do
1775+
match ← f a with
1776+
| Sum.inl b => bs := bs.push b
1777+
| Sum.inr c => cs := cs.push c
1778+
return (bs, cs)
1779+
1780+
/-! ### partitionMap -/
1781+
1782+
/--
1783+
Applies a function that returns a disjoint union to each element of an array, collecting the `Sum.inl`
1784+
and `Sum.inr` results into separate arrays.
1785+
1786+
Examples:
1787+
* `#[0, 1, 2, 3].partitionMap (fun x => if x % 2 = 0 then .inl x else .inr x) = (#[0, 2], #[1, 3])`
1788+
* `#[0, 1, 2, 3].partitionMap (fun x => if x = 0 then .inl x else .inr x) = (#[0], #[1, 2, 3])`
1789+
-/
1790+
@[inline]
1791+
def partitionMap (f : α → β ⊕ γ) (as : Array α) : Array β × Array γ :=
1792+
Id.run <| partitionMapM (fun a => pure (f a)) as
1793+
17191794
/--
17201795
Removes all the elements that satisfy a predicate from the end of an array.
17211796

src/Init/Data/List/BasicAux.lean

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,33 @@ where
198198
else
199199
go xs acc₁ (acc₂.push x)
200200

201+
/-! ### partitionMapM -/
202+
203+
/--
204+
Applies a monadic function that returns a disjoint union to each element of a list,
205+
collecting the `Sum.inl` and `Sum.inr` results into separate lists.
206+
207+
Example:
208+
```lean
209+
def f (x : Int) : Except String (Int ⊕ String) :=
210+
if x % 2 = 0 then pure (Sum.inl x)
211+
else pure (Sum.inr (toString x))
212+
213+
#eval [0, 1, 2, 3].partitionMapM f
214+
-- Except.ok ([0, 2], ["1", "3"])
215+
```
216+
-/
217+
@[inline] def partitionMapM [Monad m] (f : α → m (β ⊕ γ)) (l : List α) : m (List β × List γ) := go l #[] #[] where
218+
/-- Auxiliary for `partitionMapM`:
219+
`partitionMapM.go f l acc₁ acc₂` returns `(acc₁.toList ++ left, acc₂.toList ++ right)`
220+
if `partitionMapM f l = (left, right)`. -/
221+
@[specialize] go : List α → Array β → Array γ → m (List β × List γ)
222+
| [], acc₁, acc₂ => pure (acc₁.toList, acc₂.toList)
223+
| x :: xs, acc₁, acc₂ => do
224+
match ← f x with
225+
| .inl a => go xs (acc₁.push a) acc₂
226+
| .inr b => go xs acc₁ (acc₂.push b)
227+
201228
/-! ### partitionMap -/
202229

203230
/--
@@ -208,16 +235,7 @@ Examples:
208235
* `[0, 1, 2, 3].partitionMap (fun x => if x % 2 = 0 then .inl x else .inr x) = ([0, 2], [1, 3])`
209236
* `[0, 1, 2, 3].partitionMap (fun x => if x = 0 then .inl x else .inr x) = ([0], [1, 2, 3])`
210237
-/
211-
@[inline] def partitionMap (f : α → β ⊕ γ) (l : List α) : List β × List γ := go l #[] #[] where
212-
/-- Auxiliary for `partitionMap`:
213-
`partitionMap.go f l acc₁ acc₂ = (acc₁.toList ++ left, acc₂.toList ++ right)`
214-
if `partitionMap f l = (left, right)`. -/
215-
@[specialize] go : List α → Array β → Array γ → List β × List γ
216-
| [], acc₁, acc₂ => (acc₁.toList, acc₂.toList)
217-
| x :: xs, acc₁, acc₂ =>
218-
match f x with
219-
| .inl a => go xs (acc₁.push a) acc₂
220-
| .inr b => go xs acc₁ (acc₂.push b)
238+
@[inline] def partitionMap (f : α → β ⊕ γ) (l : List α) : List β × List γ := Id.run <| partitionMapM (fun a => pure (f a)) l
221239

222240
/-! ### mapMono
223241

0 commit comments

Comments
 (0)