Skip to content

Commit 1c79aac

Browse files
committed
feat: LawfulMonad and WPMonad instances for Option and OptionT
1 parent ac0b829 commit 1c79aac

File tree

6 files changed

+172
-26
lines changed

6 files changed

+172
-26
lines changed

src/Init/Control/Lawful/Instances.lean

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ prelude
99
public import Init.Control.Lawful.Basic
1010
public import Init.Control.Except
1111
import all Init.Control.Except
12+
public import Init.Control.Option
13+
import all Init.Control.Option
1214
public import Init.Control.State
1315
import all Init.Control.State
1416
public import Init.Control.StateRef
@@ -110,6 +112,118 @@ instance : LawfulMonad (Except ε) := LawfulMonad.mk'
110112
instance : LawfulApplicative (Except ε) := inferInstance
111113
instance : LawfulFunctor (Except ε) := inferInstance
112114

115+
/-! # OptionT -/
116+
117+
namespace OptionT
118+
119+
@[ext] theorem ext {x y : OptionT m α} (h : x.run = y.run) : x = y := by
120+
simp [run] at h
121+
assumption
122+
123+
@[simp, grind =] theorem run_pure [Monad m] (x : α) : run (pure x : OptionT m α) = pure (some x) := by
124+
simp [run, pure, OptionT.pure, OptionT.mk]
125+
126+
@[simp, grind =] theorem run_lift [Monad.{u, v} m] (x : m α) : run (OptionT.lift x : OptionT m α) = (return some (← x) : m (Option α)) := by
127+
simp [run, OptionT.lift, OptionT.mk]
128+
129+
@[simp, grind =] theorem run_throw [Monad m] : run (throw e : OptionT m β) = pure none := by
130+
simp [run, throw, throwThe, MonadExceptOf.throw, OptionT.fail, OptionT.mk]
131+
132+
@[simp, grind =] theorem run_bind_lift [Monad m] [LawfulMonad m] (x : m α) (f : α → OptionT m β) : run (OptionT.lift x >>= f : OptionT m β) = x >>= fun a => run (f a) := by
133+
simp [OptionT.run, OptionT.lift, bind, OptionT.bind, OptionT.mk]
134+
135+
@[simp, grind =] theorem bind_throw [Monad m] [LawfulMonad m] (f : α → OptionT m β) : (throw e >>= f) = throw e := by
136+
simp [throw, throwThe, MonadExceptOf.throw, bind, OptionT.bind, OptionT.mk, OptionT.fail]
137+
138+
@[simp, grind =] theorem run_bind (f : α → OptionT m β) [Monad m] :
139+
(x >>= f).run = Option.elimM x.run (pure none) (fun x => (f x).run) := by
140+
change x.run >>= _ = _
141+
simp [Option.elimM]
142+
exact bind_congr fun |some _ => rfl | none => rfl
143+
144+
@[simp, grind =] theorem lift_pure [Monad m] [LawfulMonad m] {α : Type u} (a : α) : OptionT.lift (pure a : m α) = pure a := by
145+
simp only [OptionT.lift, OptionT.mk, bind_pure_comp, map_pure, pure, OptionT.pure]
146+
147+
@[simp, grind =] theorem run_map [Monad m] [LawfulMonad m] (f : α → β) (x : OptionT m α)
148+
: (f <$> x).run = Option.map f <$> x.run := by
149+
simp [Functor.map, Option.map, ←bind_pure_comp]
150+
apply bind_congr
151+
intro a; cases a <;> simp [OptionT.pure, OptionT.mk]
152+
153+
protected theorem seq_eq {α β : Type u} [Monad m] (mf : OptionT m (α → β)) (x : OptionT m α) : mf <*> x = mf >>= fun f => f <$> x :=
154+
rfl
155+
156+
protected theorem bind_pure_comp [Monad m] (f : α → β) (x : OptionT m α) : x >>= pure ∘ f = f <$> x := by
157+
intros; rfl
158+
159+
protected theorem seqLeft_eq {α β : Type u} {m : Type u → Type v} [Monad m] [LawfulMonad m] (x : OptionT m α) (y : OptionT m β) : x <* y = const β <$> x <*> y := by
160+
change (x >>= fun a => y >>= fun _ => pure a) = (const (α := α) β <$> x) >>= fun f => f <$> y
161+
rw [← OptionT.bind_pure_comp]
162+
apply ext
163+
simp [Option.elimM, Option.elim]
164+
apply bind_congr
165+
intro
166+
| none => simp
167+
| some _ =>
168+
simp [←bind_pure_comp]; apply bind_congr; intro b;
169+
cases b <;> simp [const]
170+
171+
protected theorem seqRight_eq [Monad m] [LawfulMonad m] (x : OptionT m α) (y : OptionT m β) : x *> y = const α id <$> x <*> y := by
172+
change (x >>= fun _ => y) = (const α id <$> x) >>= fun f => f <$> y
173+
rw [← OptionT.bind_pure_comp]
174+
apply ext
175+
simp [Option.elimM, Option.elim]
176+
apply bind_congr
177+
intro a; cases a <;> simp
178+
179+
instance [Monad m] [LawfulMonad m] : LawfulMonad (OptionT m) where
180+
id_map := by intros; apply ext; simp
181+
map_const := by intros; rfl
182+
seqLeft_eq := OptionT.seqLeft_eq
183+
seqRight_eq := OptionT.seqRight_eq
184+
pure_seq := by intros; apply ext; simp [OptionT.seq_eq, Option.elimM, Option.elim]
185+
bind_pure_comp := OptionT.bind_pure_comp
186+
bind_map := by intros; rfl
187+
pure_bind := by intros; apply ext; simp [Option.elimM, Option.elim]
188+
bind_assoc := by intros; apply ext; simp [Option.elimM, Option.elim]; apply bind_congr; intro a; cases a <;> simp
189+
190+
@[simp] theorem run_seq [Monad m] [LawfulMonad m] (f : OptionT m (α → β)) (x : OptionT m α) :
191+
(f <*> x).run = Option.elimM f.run (pure none) (fun f => Option.map f <$> x.run) := by
192+
simp [seq_eq_bind, Option.elimM, Option.elim]
193+
194+
@[simp] theorem run_seqLeft [Monad m] [LawfulMonad m] (x : OptionT m α) (y : OptionT m β) :
195+
(x <* y).run = Option.elimM x.run (pure none)
196+
(fun x => Option.map (Function.const β x) <$> y.run) := by
197+
simp [seqLeft_eq, seq_eq_bind, Option.elimM, OptionT.run_bind]
198+
199+
@[simp] theorem run_seqRight [Monad m] [LawfulMonad m] (x : OptionT m α) (y : OptionT m β) :
200+
(x *> y).run = Option.elimM x.run (pure none) (Function.const α y.run) := by
201+
simp only [seqRight_eq, run_seq, Option.elimM, run_map, Option.elim, bind_map_left]
202+
refine bind_congr (fun | some _ => by simp | none => by simp)
203+
204+
@[simp, grind =] theorem run_failure [Monad m] : (failure : OptionT m α).run = pure none := by rfl
205+
206+
@[simp] theorem map_failure [Monad m] [LawfulMonad m] {α β : Type _} (f : α → β) :
207+
f <$> (failure : OptionT m α) = (failure : OptionT m β) := by
208+
simp [OptionT.mk, Functor.map, Alternative.failure, OptionT.fail, OptionT.bind]
209+
210+
@[simp] theorem run_orElse [Monad m] (x : OptionT m α) (y : OptionT m α) :
211+
(x <|> y).run = Option.elimM x.run y.run (fun x => pure (some x)) :=
212+
bind_congr fun | some _ => by rfl | none => by rfl
213+
214+
end OptionT
215+
216+
/-! # Option -/
217+
218+
instance : LawfulMonad Option := LawfulMonad.mk'
219+
(id_map := fun x => by cases x <;> rfl)
220+
(pure_bind := fun _ _ => by rfl)
221+
(bind_assoc := fun a _ _ => by cases a <;> rfl)
222+
(bind_pure_comp := bind_pure_comp)
223+
224+
instance : LawfulApplicative Option := inferInstance
225+
instance : LawfulFunctor Option := inferInstance
226+
113227
/-! # ReaderT -/
114228

115229
namespace ReaderT

src/Init/Control/Lawful/MonadLift/Instances.lean

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,6 @@ namespace OptionT
6464

6565
variable [Monad m] [LawfulMonad m]
6666

67-
@[simp]
68-
theorem lift_pure {α : Type u} (a : α) : OptionT.lift (pure a : m α) = pure a := by
69-
simp only [OptionT.lift, OptionT.mk, bind_pure_comp, map_pure, pure, OptionT.pure]
70-
7167
@[simp]
7268
theorem lift_bind {α β : Type u} (ma : m α) (f : α → m β) :
7369
OptionT.lift (ma >>= f) = OptionT.lift ma >>= (fun a => OptionT.lift (f a)) := by

src/Init/Control/Option.lean

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ variable {m : Type u → Type v} [Monad m] {α β : Type u}
3939
Converts an action that returns an `Option` into one that might fail, with `none` indicating
4040
failure.
4141
-/
42+
@[always_inline, inline, expose]
4243
protected def mk (x : m (Option α)) : OptionT m α :=
4344
x
4445

4546
/--
4647
Sequences two potentially-failing actions. The second action is run only if the first succeeds.
4748
-/
48-
@[always_inline, inline]
49+
@[always_inline, inline, expose]
4950
protected def bind (x : OptionT m α) (f : α → OptionT m β) : OptionT m β := OptionT.mk do
5051
match (← x) with
5152
| some a => f a
@@ -54,7 +55,7 @@ protected def bind (x : OptionT m α) (f : α → OptionT m β) : OptionT m β :
5455
/--
5556
Succeeds with the provided value.
5657
-/
57-
@[always_inline, inline]
58+
@[always_inline, inline, expose]
5859
protected def pure (a : α) : OptionT m α := OptionT.mk do
5960
pure (some a)
6061

src/Std/Do/PredTrans.lean

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -137,28 +137,37 @@ instance instLawfulMonad : LawfulMonad (PredTrans ps) :=
137137

138138
-- The interpretation of `StateT σ (PredTrans ps) α` into `PredTrans (.arg σ ps) α`.
139139
-- Think: modifyGetM
140-
def pushArg {σ : Type u} (x : StateT σ (PredTrans ps) α) : PredTrans (.arg σ ps) α :=
141-
{ apply := fun Q s => (x s).apply (fun (a, s) => Q.1 a s, Q.2),
142-
conjunctive := by
143-
intro Q₁ Q₂
144-
apply SPred.bientails.of_eq
145-
ext s
146-
dsimp only [SPred.and_cons, ExceptConds.and]
147-
rw [← ((x s).conjunctive _ _).to_eq]
148-
}
140+
def pushArg {σ : Type u} (x : StateT σ (PredTrans ps) α) : PredTrans (.arg σ ps) α where
141+
apply := fun Q s => (x s).apply (fun (a, s) => Q.1 a s, Q.2)
142+
conjunctive := by
143+
intro Q₁ Q₂
144+
apply SPred.bientails.of_eq
145+
ext s
146+
dsimp only [SPred.and_cons, ExceptConds.and]
147+
rw [← ((x s).conjunctive _ _).to_eq]
149148

150149
-- The interpretation of `ExceptT ε (PredTrans ps) α` into `PredTrans (.except ε ps) α`
151-
def pushExcept {ps : PostShape} {α ε} (x : ExceptT ε (PredTrans ps) α) : PredTrans (.except ε ps) α :=
152-
{ apply Q := x.apply (fun | .ok a => Q.1 a | .error e => Q.2.1 e, Q.2.2),
153-
conjunctive := by
154-
intro Q₁ Q₂
155-
apply SPred.bientails.of_eq
156-
dsimp
157-
rw[← (x.conjunctive _ _).to_eq]
158-
congr
159-
ext x
160-
cases x <;> simp
161-
}
150+
def pushExcept {ps : PostShape} {α ε} (x : ExceptT ε (PredTrans ps) α) : PredTrans (.except ε ps) α where
151+
apply Q := x.apply (fun | .ok a => Q.1 a | .error e => Q.2.1 e, Q.2.2)
152+
conjunctive := by
153+
intro Q₁ Q₂
154+
apply SPred.bientails.of_eq
155+
dsimp
156+
rw[← (x.conjunctive _ _).to_eq]
157+
congr
158+
ext x
159+
cases x <;> simp
160+
161+
def pushOption {ps : PostShape} {α} (x : OptionT (PredTrans ps) α) : PredTrans (.except PUnit ps) α where
162+
apply Q := x.apply (fun | .some a => Q.1 a | .none => Q.2.1 ⟨⟩, Q.2.2)
163+
conjunctive := by
164+
intro Q₁ Q₂
165+
apply SPred.bientails.of_eq
166+
dsimp
167+
rw[← (x.conjunctive _ _).to_eq]
168+
congr
169+
ext x
170+
cases x <;> simp
162171

163172
@[simp]
164173
def pushArg_apply {ps} {α σ : Type u} {Q : PostCond α (.arg σ ps)} (f : σ → PredTrans ps (α × σ)) :
@@ -168,6 +177,10 @@ def pushArg_apply {ps} {α σ : Type u} {Q : PostCond α (.arg σ ps)} (f : σ
168177
def pushExcept_apply {ps} {α ε : Type u} {Q : PostCond α (.except ε ps)} (x : PredTrans ps (Except ε α)) :
169178
(pushExcept x).apply Q = x.apply (fun | .ok a => Q.1 a | .error e => Q.2.1 e, Q.2.2) := rfl
170179

180+
@[simp]
181+
def pushOption_apply {ps} {α : Type u} {Q : PostCond α (.except PUnit ps)} (x : PredTrans ps (Option α)) :
182+
(pushOption x).apply Q = x.apply (fun | .some a => Q.1 a | .none => Q.2.1 ⟨⟩, Q.2.2) := rfl
183+
171184
def dite_apply {ps} {Q : PostCond α ps} (c : Prop) [Decidable c] (t : c → PredTrans ps α) (e : ¬ c → PredTrans ps α) :
172185
(if h : c then t h else e h).apply Q = if h : c then (t h).apply Q else (e h).apply Q := by split <;> rfl
173186

src/Std/Do/WP/Basic.lean

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ protected meta def unexpandWP : Lean.PrettyPrinter.Unexpander
6767
| `(wp $e) => `(wp⟦$e⟧)
6868
| _ => throw ()
6969
| _ => throw ()
70+
7071
instance Id.instWP : WP Id .pure where
7172
wp x := PredTrans.pure x.run
7273

@@ -79,6 +80,9 @@ instance ReaderT.instWP [WP m ps] : WP (ReaderT ρ m) (.arg ρ ps) where
7980
instance ExceptT.instWP [WP m ps] : WP (ExceptT ε m) (.except ε ps) where
8081
wp x := PredTrans.pushExcept (wp x)
8182

83+
instance OptionT.instWP [WP m ps] : WP (OptionT m) (.except PUnit ps) where
84+
wp x := PredTrans.pushOption (wp x)
85+
8286
instance EStateM.instWP : WP (EStateM ε σ) (.except ε (.arg σ .pure)) where
8387
wp x := -- Could define as PredTrans.mkExcept (PredTrans.modifyGetM (fun s => pure (EStateM.Result.toExceptState (x s))))
8488
{ apply := fun Q s => match x s with
@@ -98,6 +102,8 @@ instance Reader.instWP : WP (ReaderM ρ) (.arg ρ .pure) :=
98102
inferInstanceAs (WP (ReaderT ρ Id) (.arg ρ .pure))
99103
instance Except.instWP : WP (Except ε) (.except ε .pure) :=
100104
inferInstanceAs (WP (ExceptT ε Id) (.except ε .pure))
105+
instance Option.instWP : WP Option (.except PUnit .pure) :=
106+
inferInstanceAs (WP (OptionT Id) (.except PUnit .pure))
101107

102108
/--
103109
Adequacy lemma for `Id.run`.

src/Std/Do/WP/Monad.lean

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,18 @@ instance ExceptT.instWPMonad [Monad m] [WPMonad m ps] : WPMonad (ExceptT ε m) (
7272
case error a => simp [wp_pure]
7373
case ok a => rfl
7474

75+
instance OptionT.instWPMonad [Monad m] [WPMonad m ps] : WPMonad (OptionT m) (.except PUnit ps) where
76+
wp_pure a := by ext; simp only [wp, pure, OptionT.pure, OptionT.mk, WPMonad.wp_pure,
77+
PredTrans.pure, PredTrans.pushOption_apply]
78+
wp_bind x f := by
79+
ext Q
80+
simp only [wp, bind, OptionT.bind, OptionT.mk, WPMonad.wp_bind, PredTrans.bind, PredTrans.pushOption_apply]
81+
congr
82+
ext b
83+
cases b
84+
case none => simp [wp_pure]
85+
case some a => rfl
86+
7587
instance EStateM.instWPMonad : WPMonad (EStateM ε σ) (.except ε (.arg σ .pure)) where
7688
wp_pure a := by simp only [wp, pure, EStateM.pure, PredTrans.pure]
7789
wp_bind x f := by
@@ -84,6 +96,10 @@ instance Except.instWPMonad : WPMonad (Except ε) (.except ε .pure) where
8496
wp_pure a := rfl
8597
wp_bind x f := by cases x <;> rfl
8698

99+
instance Option.instWPMonad : WPMonad Option (.except PUnit .pure) where
100+
wp_pure a := rfl
101+
wp_bind x f := by cases x <;> rfl
102+
87103
instance State.instWPMonad : WPMonad (StateM σ) (.arg σ .pure) :=
88104
inferInstanceAs (WPMonad (StateT σ Id) (.arg σ .pure))
89105
instance Reader.instWPMonad : WPMonad (ReaderM ρ) (.arg ρ .pure) :=

0 commit comments

Comments
 (0)