@@ -1692,6 +1692,41 @@ def getMax? (as : Array α) (lt : α → α → Bool) : Option α :=
16921692 else
16931693 none
16941694
1695+ /-! ### partitionM -/
1696+
1697+ /--
1698+ Returns a pair of arrays that together contain all the elements of `as`. The first array contains
1699+ those elements for which the monadic predicate `p` returns `true`, and the second contains those for
1700+ which `p` returns `false`. The array's elements are examined in order, from left to right.
1701+
1702+ This is a monadic version of `Array.partition`.
1703+
1704+ Example:
1705+ ```lean
1706+ def posOrNeg (x : Int) : Except String Bool :=
1707+ if x > 0 then pure true
1708+ else if x < 0 then pure false
1709+ else throw "Zero is not positive or negative"
1710+
1711+ #eval #[-1, 2, 3].partitionM posOrNeg
1712+ -- Except.ok (#[2, 3], #[-1])
1713+
1714+ #eval #[0, 2, 3].partitionM posOrNeg
1715+ -- Except.error "Zero is not positive or negative"
1716+ ```
1717+ -/
1718+
1719+ @[inline, specialize]
1720+ def partitionM {α : Type } [Monad m] (p : α → m Bool) (as : Array α) : m (Array α × Array α) := do
1721+ let mut bs := #[]
1722+ let mut cs := #[]
1723+ for a in as do
1724+ if (← p a) then
1725+ bs := bs.push a
1726+ else
1727+ cs := cs.push a
1728+ return (bs, cs)
1729+
16951730/--
16961731Returns a pair of arrays that together contain all the elements of `as`. The first array contains
16971732those elements for which `p` returns `true`, and the second contains those for which `p` returns
@@ -1716,6 +1751,46 @@ def partition (p : α → Bool) (as : Array α) : Array α × Array α := Id.run
17161751 cs := cs.push a
17171752 return (bs, cs)
17181753
1754+ /-! ### partitionMapM -/
1755+
1756+ /--
1757+ Applies a monadic function that returns a disjoint union to each element of an array,
1758+ collecting the `Sum.inl` and `Sum.inr` results into separate arrays.
1759+
1760+ Example:
1761+ ```lean
1762+ def f (x : Int) : Except String (Int ⊕ String) :=
1763+ if x % 2 = 0 then pure (Sum.inl x)
1764+ else pure (Sum.inr (toString x))
1765+
1766+ #eval #[0, 1, 2, 3].partitionMapM f
1767+ -- Except.ok (#[0, 2], #["1", "3"])
1768+ ```
1769+ -/
1770+ @[inline, specialize]
1771+ def partitionMapM [Monad m] (f : α → m (β ⊕ γ)) (as : Array α) : m (Array β × Array γ) := do
1772+ let mut bs := #[]
1773+ let mut cs := #[]
1774+ for a in as do
1775+ match ← f a with
1776+ | Sum.inl b => bs := bs.push b
1777+ | Sum.inr c => cs := cs.push c
1778+ return (bs, cs)
1779+
1780+ /-! ### partitionMap -/
1781+
1782+ /--
1783+ Applies a function that returns a disjoint union to each element of an array, collecting the `Sum.inl`
1784+ and `Sum.inr` results into separate arrays.
1785+
1786+ Examples:
1787+ * `#[0, 1, 2, 3].partitionMap (fun x => if x % 2 = 0 then .inl x else .inr x) = (#[0, 2], #[1, 3])`
1788+ * `#[0, 1, 2, 3].partitionMap (fun x => if x = 0 then .inl x else .inr x) = (#[0], #[1, 2, 3])`
1789+ -/
1790+ @[inline]
1791+ def partitionMap (f : α → β ⊕ γ) (as : Array α) : Array β × Array γ :=
1792+ Id.run <| partitionMapM (fun a => pure (f a)) as
1793+
17191794/--
17201795Removes all the elements that satisfy a predicate from the end of an array.
17211796
0 commit comments