Skip to content

Commit 95e33d8

Browse files
authored
feat: add MonadControl lemmas for ReaderT, OptionT, StateT, and ExceptT (#11591)
This PR adds missing lemmas about how `ReaderT.run`, `OptionT.run`, `StateT.run`, and `ExceptT.run` interact with `MonadControl` operations. This also leaves some comments noting that the lemmas may look less general than expected; but this is because the instances are also not very general.
1 parent 351a941 commit 95e33d8

File tree

1 file changed

+82
-5
lines changed

1 file changed

+82
-5
lines changed

src/Init/Control/Lawful/Instances.lean

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ public section
1717

1818
open Function
1919

20+
@[simp, grind =] theorem monadMap_refl {m : Type _ → Type _} {α} (f : ∀ {α}, m α → m α) :
21+
monadMap @f = @f α := rfl
22+
2023
/-! # ExceptT -/
2124

2225
namespace ExceptT
@@ -57,6 +60,9 @@ theorem run_bind [Monad m] (x : ExceptT ε m α) (f : α → ExceptT ε m β)
5760
apply bind_congr
5861
intro a; cases a <;> simp [Except.map]
5962

63+
@[simp, grind =] theorem run_monadMap [MonadFunctorT n m] (f : {β : Type u} → n β → n β) (x : ExceptT ε m α)
64+
: (monadMap @f x : ExceptT ε m α).run = monadMap @f (x.run) := rfl
65+
6066
protected theorem seq_eq {α β ε : Type u} [Monad m] (mf : ExceptT ε m (α → β)) (x : ExceptT ε m α) : mf <*> x = mf >>= fun f => f <$> x :=
6167
rfl
6268

@@ -99,6 +105,22 @@ instance [Monad m] [LawfulMonad m] : LawfulMonad (ExceptT ε m) where
99105
simp only [ExceptT.instMonad, ExceptT.map, ExceptT.mk, throw, throwThe, MonadExceptOf.throw,
100106
pure_bind]
101107

108+
/-! Note that the `MonadControl` instance for `ExceptT` is not monad-generic. -/
109+
110+
@[simp] theorem run_restoreM [Monad m] (x : stM m (ExceptT ε m) α) :
111+
ExceptT.run (restoreM x) = pure x := rfl
112+
113+
@[simp] theorem run_liftWith [Monad m] (f : ({β : Type u} → ExceptT ε m β → m (stM m (ExceptT ε m) β)) → m α) :
114+
ExceptT.run (liftWith f) = Except.ok <$> (f fun x => x.run) :=
115+
rfl
116+
117+
@[simp] theorem run_controlAt [Monad m] [LawfulMonad m] (f : ({β : Type u} → ExceptT ε m β → m (stM m (ExceptT ε m) β)) → m (stM m (ExceptT ε m) α)) :
118+
ExceptT.run (controlAt m f) = f fun x => x.run := by
119+
simp [controlAt, run_bind, bind_map_left]
120+
121+
@[simp] theorem run_control [Monad m] [LawfulMonad m] (f : ({β : Type u} → ExceptT ε m β → m (stM m (ExceptT ε m) β)) → m (stM m (ExceptT ε m) α)) :
122+
ExceptT.run (control f) = f fun x => x.run := run_controlAt f
123+
102124
end ExceptT
103125

104126
/-! # Except -/
@@ -152,6 +174,9 @@ namespace OptionT
152174
apply bind_congr
153175
intro a; cases a <;> simp [OptionT.pure, OptionT.mk]
154176

177+
@[simp, grind =] theorem run_monadMap [MonadFunctorT n m] (f : {β : Type u} → n β → n β) (x : OptionT m α)
178+
: (monadMap @f x : OptionT m α).run = monadMap @f (x.run) := rfl
179+
155180
protected theorem seq_eq {α β : Type u} [Monad m] (mf : OptionT m (α → β)) (x : OptionT m α) : mf <*> x = mf >>= fun f => f <$> x :=
156181
rfl
157182

@@ -213,6 +238,24 @@ instance [Monad m] [LawfulMonad m] : LawfulMonad (OptionT m) where
213238
(x <|> y).run = Option.elimM x.run y.run (fun x => pure (some x)) :=
214239
bind_congr fun | some _ => by rfl | none => by rfl
215240

241+
/-! Note that the `MonadControl` instance for `OptionT` is not monad-generic. -/
242+
243+
@[simp] theorem run_restoreM [Monad m] (x : stM m (OptionT m) α) :
244+
OptionT.run (restoreM x) = pure x := rfl
245+
246+
@[simp] theorem run_liftWith [Monad m] [LawfulMonad m] (f : ({β : Type u} → OptionT m β → m (stM m (OptionT m) β)) → m α) :
247+
OptionT.run (liftWith f) = Option.some <$> (f fun x => x.run) := by
248+
dsimp [liftWith]
249+
rw [← bind_pure_comp]
250+
rfl
251+
252+
@[simp] theorem run_controlAt [Monad m] [LawfulMonad m] (f : ({β : Type u} → OptionT m β → m (stM m (OptionT m) β)) → m (stM m (OptionT m) α)) :
253+
OptionT.run (controlAt m f) = f fun x => x.run := by
254+
simp [controlAt, Option.elimM, Option.elim]
255+
256+
@[simp] theorem run_control [Monad m] [LawfulMonad m] (f : ({β : Type u} → OptionT m β → m (stM m (OptionT m) β)) → m (stM m (OptionT m) α)) :
257+
OptionT.run (control f) = f fun x => x.run := run_controlAt f
258+
216259
end OptionT
217260

218261
/-! # Option -/
@@ -284,6 +327,22 @@ instance [Monad m] [LawfulMonad m] : LawfulMonad (ReaderT ρ m) where
284327
pure_bind := by intros; apply ext; intros; simp
285328
bind_assoc := by intros; apply ext; intros; simp
286329

330+
/-! Note that the `MonadControl` instance for `ReaderT` is not monad-generic. -/
331+
332+
@[simp] theorem run_restoreM [Monad m] (x : stM m (ReaderT ρ m) α) (ctx : ρ) :
333+
ReaderT.run (restoreM x) ctx = pure x := rfl
334+
335+
@[simp] theorem run_liftWith [Monad m] (f : ({β : Type u} → ReaderT ρ m β → m (stM m (ReaderT ρ m) β)) → m α) (ctx : ρ) :
336+
ReaderT.run (liftWith f) ctx = (f fun x => x.run ctx) :=
337+
rfl
338+
339+
@[simp] theorem run_controlAt [Monad m] [LawfulMonad m] (f : ({β : Type u} → ReaderT ρ m β → m (stM m (ReaderT ρ m) β)) → m (stM m (ReaderT ρ m) α)) (ctx : ρ) :
340+
ReaderT.run (controlAt m f) ctx = f fun x => x.run ctx := by
341+
simp [controlAt]
342+
343+
@[simp] theorem run_control [Monad m] [LawfulMonad m] (f : ({β : Type u} → ReaderT ρ m β → m (stM m (ReaderT ρ m) β)) → m (stM m (ReaderT ρ m) α)) (ctx : ρ) :
344+
ReaderT.run (control f) ctx = f fun x => x.run ctx := run_controlAt f ctx
345+
287346
end ReaderT
288347

289348
/-! # StateRefT -/
@@ -307,11 +366,11 @@ namespace StateT
307366
@[simp, grind =] theorem run_pure [Monad m] (a : α) (s : σ) : (pure a : StateT σ m α).run s = pure (a, s) := rfl
308367

309368
@[simp, grind =] theorem run_bind [Monad m] (x : StateT σ m α) (f : α → StateT σ m β) (s : σ)
310-
: (x >>= f).run s = x.run s >>= λ p => (f p.1).run p.2 := by
311-
simp [bind, StateT.bind, run]
369+
: (x >>= f).run s = x.run s >>= λ p => (f p.1).run p.2 := rfl
312370

313371
@[simp, grind =] theorem run_map {α β σ : Type u} [Monad m] [LawfulMonad m] (f : α → β) (x : StateT σ m α) (s : σ) : (f <$> x).run s = (fun (p : α × σ) => (f p.1, p.2)) <$> x.run s := by
314-
simp [Functor.map, StateT.map, run, ←bind_pure_comp]
372+
rw [← bind_pure_comp (m := m)]
373+
rfl
315374

316375
@[simp, grind =] theorem run_get [Monad m] (s : σ) : (get : StateT σ m σ).run s = pure (s, s) := rfl
317376

@@ -320,13 +379,13 @@ namespace StateT
320379
@[simp, grind =] theorem run_modify [Monad m] (f : σ → σ) (s : σ) : (modify f : StateT σ m PUnit).run s = pure (⟨⟩, f s) := rfl
321380

322381
@[simp, grind =] theorem run_modifyGet [Monad m] (f : σ → α × σ) (s : σ) : (modifyGet f : StateT σ m α).run s = pure ((f s).1, (f s).2) := by
323-
simp [modifyGet, MonadStateOf.modifyGet, StateT.modifyGet, run]
382+
rfl
324383

325384
@[simp, grind =] theorem run_lift {α σ : Type u} [Monad m] (x : m α) (s : σ) : (StateT.lift x : StateT σ m α).run s = x >>= fun a => pure (a, s) := rfl
326385

327386
@[grind =]
328387
theorem run_bind_lift {α σ : Type u} [Monad m] [LawfulMonad m] (x : m α) (f : α → StateT σ m β) (s : σ) : (StateT.lift x >>= f).run s = x >>= fun a => (f a).run s := by
329-
simp [StateT.lift, StateT.run, bind, StateT.bind]
388+
simp
330389

331390
@[simp, grind =] theorem run_monadLift {α σ : Type u} [Monad m] [MonadLiftT n m] (x : n α) (s : σ) : (monadLift x : StateT σ m α).run s = (monadLift x : m α) >>= fun a => pure (a, s) := rfl
332391

@@ -366,6 +425,24 @@ instance [Monad m] [LawfulMonad m] : LawfulMonad (StateT σ m) where
366425
pure_bind := by intros; apply ext; intros; simp
367426
bind_assoc := by intros; apply ext; intros; simp
368427

428+
/-! Note that the `MonadControl` instance for `StateT` is not monad-generic. -/
429+
430+
@[simp] theorem run_restoreM [Monad m] [LawfulMonad m] (x : stM m (StateT σ m) α) (s : σ) :
431+
StateT.run (restoreM x) s = pure x := by
432+
simp [restoreM, MonadControl.restoreM]
433+
rfl
434+
435+
@[simp] theorem run_liftWith [Monad m] [LawfulMonad m] (f : ({β : Type u} → StateT σ m β → m (stM m (StateT σ m) β)) → m α) (s : σ) :
436+
StateT.run (liftWith f) s = ((·, s) <$> f fun x => x.run s) := by
437+
simp [liftWith, MonadControl.liftWith, Function.comp_def]
438+
439+
@[simp] theorem run_controlAt [Monad m] [LawfulMonad m] (f : ({β : Type u} → StateT σ m β → m (stM m (StateT σ m) β)) → m (stM m (StateT σ m) α)) (s : σ) :
440+
StateT.run (controlAt m f) s = f fun x => x.run s := by
441+
simp [controlAt]
442+
443+
@[simp] theorem run_control [Monad m] [LawfulMonad m] (f : ({β : Type u} → StateT σ m β → m (stM m (StateT σ m) β)) → m (stM m (StateT σ m) α)) (s : σ) :
444+
StateT.run (control f) s = f fun x => x.run s := run_controlAt f s
445+
369446
end StateT
370447

371448
/-! # EStateM -/

0 commit comments

Comments
 (0)